Commit 4aed4b4 (verified) · nhatminh committed · 1 parent: 092552a

Upload 12 files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
block.py ADDED
@@ -0,0 +1,470 @@
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/block.py
2
+ # Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
3
+
4
+ # Copyright (c) 2024, Tri Dao.
5
+
6
+ from functools import partial
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import torch.fx
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch import Tensor
14
+
15
+ from .mha import MHA
16
+ from .mlp import Mlp
17
+
18
+ try:
19
+ from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
20
+ except ImportError:
21
+ layer_norm_fn, RMSNorm = None, None
22
+
23
+
24
+ def stochastic_depth(
25
+ input: Tensor, p: float, mode: str, training: bool = True
26
+ ) -> Tensor:
27
+ """
28
+ Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth"
29
+ <https://arxiv.org/abs/1603.09382>`_ used for randomly dropping residual
30
+ branches of residual architectures.
31
+ Args:
32
+ input (Tensor[N, ...]): The input tensor of arbitrary dimensions with the first one
33
+ being its batch i.e. a batch with ``N`` rows.
34
+ p (float): probability of the input to be zeroed.
35
+ mode (str): ``"batch"`` or ``"row"``.
36
+ ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
37
+ randomly selected rows from the batch.
38
+ training: apply stochastic depth if it is ``True``. Default: ``True``
39
+ Returns:
40
+ Tensor[N, ...]: The randomly zeroed tensor.
41
+ """
42
+ if p < 0.0 or p > 1.0:
43
+ raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
44
+ if mode not in ["batch", "row"]:
45
+ raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
46
+ if not training or p == 0.0:
47
+ return input
48
+
49
+ survival_rate = 1.0 - p
50
+ if mode == "row":
51
+ size = [input.shape[0]] + [1] * (input.ndim - 1)
52
+ else:
53
+ size = [1] * input.ndim
54
+ noise = torch.empty(size, dtype=input.dtype, device=input.device)
55
+ noise = noise.bernoulli_(survival_rate)
56
+ if survival_rate > 0.0:
57
+ noise.div_(survival_rate)
58
+ return input * noise
59
+
60
+
61
+ torch.fx.wrap("stochastic_depth")
62
+
63
+
64
+ class StochasticDepth(nn.Module):
65
+ """
66
+ See :func:`stochastic_depth`.
67
+ """
68
+
69
+ def __init__(self, p: float, mode: str) -> None:
70
+ super().__init__()
71
+ self.p = p
72
+ self.mode = mode
73
+
74
+ def forward(self, input: Tensor) -> Tensor:
75
+ return stochastic_depth(input, self.p, self.mode, self.training)
76
+
77
+ def __repr__(self) -> str:
78
+ s = f"{self.__class__.__name__}(p={self.p}, mode={self.mode})"
79
+ return s
80
+
81
+
82
+ class Block(nn.Module):
83
+ def __init__(
84
+ self,
85
+ dim,
86
+ mixer_cls=None,
87
+ mlp_cls=None,
88
+ norm_cls=nn.LayerNorm,
89
+ dropout_cls=nn.Dropout,
90
+ prenorm=True,
91
+ resid_dropout1=0.0,
92
+ resid_dropout2=0.0,
93
+ drop_path1=0.0,
94
+ drop_path2=0.0,
95
+ fused_dropout_add_ln=False,
96
+ return_residual=False,
97
+ residual_in_fp32=False,
98
+ sequence_parallel=False,
99
+ mark_shared_params=False,
100
+ ):
101
+ """
102
+ For prenorm=True, this Block has a slightly different structure compared to a regular
103
+ prenorm Transformer block.
104
+ The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
105
+ [Ref: https://arxiv.org/abs/2002.04745]
106
+ Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
107
+ the hidden_states (output of the MLP) and the residual.
108
+ This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
109
+ The residual needs to be provided (except for the very first block).
110
+
111
+ For prenorm=False, this Block has the same structure as a regular postnorm Transformer
112
+ block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.
113
+
114
+ return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
115
+ This is for performance reasons: for post-norm architecture, returning the input allows us
116
+ to fuse the backward of nn.Linear with the residual connection.
117
+ """
118
+ super().__init__()
119
+ self.prenorm = prenorm
120
+ self.fused_dropout_add_ln = fused_dropout_add_ln
121
+ self.return_residual = return_residual
122
+ self.residual_in_fp32 = residual_in_fp32
123
+ if self.residual_in_fp32:
124
+ assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
125
+ if mixer_cls is None:
126
+ mixer_cls = partial(MHA, num_heads=dim // 64)
127
+ if mlp_cls is None:
128
+ mlp_cls = partial(Mlp, hidden_features=4 * dim)
129
+ self.mixer = mixer_cls(dim)
130
+ self.dropout1 = dropout_cls(resid_dropout1)
131
+ self.drop_path1 = StochasticDepth(drop_path1, mode="row")
132
+ self.norm1 = norm_cls(dim)
133
+ self.mlp = mlp_cls(dim)
134
+ if not isinstance(self.mlp, nn.Identity):
135
+ self.dropout2 = dropout_cls(resid_dropout2)
136
+ self.drop_path2 = StochasticDepth(drop_path2, mode="row")
137
+ self.norm2 = norm_cls(dim)
138
+
139
+ if self.fused_dropout_add_ln:
140
+ assert layer_norm_fn is not None, "Triton is not installed"
141
+ assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
142
+ self.dropout1, nn.Dropout
143
+ )
144
+
145
+ # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
146
+ # then the input to each worker in the tensor parallel group will be different.
147
+ # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
148
+ # For now this is not an issue because we always use sequence_parallel=True during training
149
+ # and only use sequence_parallel=False during inference.
150
+
151
+ # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
152
+ if sequence_parallel:
153
+ for p in self.norm1.parameters():
154
+ p._sequence_parallel = True
155
+ if hasattr(self, "norm2"):
156
+ for p in self.norm2.parameters():
157
+ p._sequence_parallel = True
158
+ # Mark the norm parameters as "shared_params" so that we sync their values at init.
159
+ if mark_shared_params:
160
+ for p in self.norm1.parameters():
161
+ p._shared_params = True
162
+ if hasattr(self, "norm2"):
163
+ for p in self.norm2.parameters():
164
+ p._shared_params = True
165
+
166
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
167
+ return self.mixer.allocate_inference_cache(
168
+ batch_size, max_seqlen, dtype=dtype, **kwargs
169
+ )
170
+
171
+ def forward(
172
+ self,
173
+ hidden_states: Tensor,
174
+ residual: Optional[Tensor] = None,
175
+ mixer_subset=None,
176
+ mixer_kwargs=None,
177
+ ):
178
+ r"""Pass the input through the encoder layer.
179
+
180
+ Args:
181
+ hidden_states: the sequence to the encoder layer (required).
182
+ residual: if postnorm, residual=None. If prenorm, hidden_states = Attn/MLP(LN(residual))
183
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
184
+ before applying the query projection. Useful for e.g., ViT where we only care
185
+ about the CLS token in the last layer.
186
+ """
187
+ if self.prenorm:
188
+ if not self.fused_dropout_add_ln:
189
+ dropped = self.drop_path1(self.dropout1(hidden_states))
190
+ residual = (dropped + residual) if residual is not None else dropped
191
+ hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
192
+ if self.residual_in_fp32:
193
+ residual = residual.to(torch.float32)
194
+ else:
195
+ if self.drop_path1.p == 0 or not self.training:
196
+ rowscale1 = None
197
+ else:
198
+ rowscale1 = self.drop_path1(
199
+ torch.ones(
200
+ hidden_states.shape[:-1],
201
+ device=hidden_states.device,
202
+ dtype=hidden_states.dtype,
203
+ )
204
+ )
205
+ hidden_states, residual = layer_norm_fn(
206
+ hidden_states,
207
+ self.norm1.weight,
208
+ self.norm1.bias,
209
+ residual=residual,
210
+ eps=self.norm1.eps,
211
+ dropout_p=self.dropout1.p if self.training else 0.0,
212
+ rowscale=rowscale1,
213
+ prenorm=True,
214
+ residual_in_fp32=self.residual_in_fp32,
215
+ is_rms_norm=isinstance(self.norm1, RMSNorm),
216
+ )
217
+ if mixer_kwargs is None:
218
+ mixer_kwargs = {}
219
+ if mixer_subset is not None:
220
+ mixer_kwargs["mixer_subset"] = mixer_subset
221
+ hidden_states = self.mixer(hidden_states, **mixer_kwargs)
222
+ if mixer_subset is not None:
223
+ residual = residual[:, mixer_subset]
224
+ if not isinstance(self.mlp, nn.Identity):
225
+ if not self.fused_dropout_add_ln:
226
+ dropped = self.drop_path2(self.dropout2(hidden_states))
227
+ residual = (dropped + residual) if residual is not None else dropped
228
+ hidden_states = self.norm2(
229
+ residual.to(dtype=self.norm2.weight.dtype)
230
+ )
231
+ if self.residual_in_fp32:
232
+ residual = residual.to(torch.float32)
233
+ else:
234
+ if self.drop_path2.p == 0 or not self.training:
235
+ rowscale2 = None
236
+ else:
237
+ rowscale2 = self.drop_path2(
238
+ torch.ones(
239
+ hidden_states.shape[:-1],
240
+ device=hidden_states.device,
241
+ dtype=hidden_states.dtype,
242
+ )
243
+ )
244
+ hidden_states, residual = layer_norm_fn(
245
+ hidden_states,
246
+ self.norm2.weight,
247
+ self.norm2.bias,
248
+ residual=residual,
249
+ eps=self.norm2.eps,
250
+ dropout_p=self.dropout2.p if self.training else 0.0,
251
+ rowscale=rowscale2,
252
+ prenorm=True,
253
+ residual_in_fp32=self.residual_in_fp32,
254
+ is_rms_norm=isinstance(self.norm2, RMSNorm),
255
+ )
256
+ hidden_states = self.mlp(hidden_states)
257
+ return hidden_states, residual
258
+ else:
259
+ assert residual is None
260
+ mixer_out = self.mixer(
261
+ hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
262
+ )
263
+ if self.return_residual: # mixer out is actually a pair here
264
+ mixer_out, hidden_states = mixer_out
265
+ if not self.fused_dropout_add_ln:
266
+ hidden_states = self.norm1(
267
+ (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
268
+ dtype=self.norm1.weight.dtype
269
+ )
270
+ )
271
+ else:
272
+ if self.drop_path1.p == 0 or not self.training:
273
+ rowscale1 = None
274
+ else:
275
+ rowscale1 = self.drop_path1(
276
+ torch.ones(
277
+ mixer_out.shape[:-1],
278
+ device=mixer_out.device,
279
+ dtype=mixer_out.dtype,
280
+ )
281
+ )
282
+ hidden_states = layer_norm_fn(
283
+ mixer_out,
284
+ self.norm1.weight,
285
+ self.norm1.bias,
286
+ residual=hidden_states,
287
+ eps=self.norm1.eps,
288
+ dropout_p=self.dropout1.p if self.training else 0.0,
289
+ rowscale=rowscale1,
290
+ prenorm=False,
291
+ is_rms_norm=isinstance(self.norm1, RMSNorm),
292
+ )
293
+ if not isinstance(self.mlp, nn.Identity):
294
+ mlp_out = self.mlp(hidden_states)
295
+ if self.return_residual: # mlp out is actually a pair here
296
+ mlp_out, hidden_states = mlp_out
297
+ if not self.fused_dropout_add_ln:
298
+ hidden_states = self.norm2(
299
+ (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
300
+ dtype=self.norm2.weight.dtype
301
+ )
302
+ )
303
+ else:
304
+ if self.drop_path2.p == 0 or not self.training:
305
+ rowscale2 = None
306
+ else:
307
+ rowscale2 = self.drop_path2(
308
+ torch.ones(
309
+ mlp_out.shape[:-1],
310
+ device=mlp_out.device,
311
+ dtype=mlp_out.dtype,
312
+ )
313
+ )
314
+ hidden_states = layer_norm_fn(
315
+ mlp_out,
316
+ self.norm2.weight,
317
+ self.norm2.bias,
318
+ residual=hidden_states,
319
+ eps=self.norm2.eps,
320
+ dropout_p=self.dropout2.p if self.training else 0.0,
321
+ rowscale=rowscale2,
322
+ prenorm=False,
323
+ is_rms_norm=isinstance(self.norm2, RMSNorm),
324
+ )
325
+ return hidden_states
326
+
327
+
328
+ class ParallelBlock(nn.Module):
329
+ """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
330
+ and PaLM.
331
+ """
332
+
333
+ def __init__(
334
+ self,
335
+ dim,
336
+ mixer_cls=None,
337
+ mlp_cls=None,
338
+ norm_cls=nn.LayerNorm,
339
+ dropout_cls=nn.Dropout,
340
+ resid_dropout1=0.0,
341
+ resid_dropout2=0.0,
342
+ tied_norm=False,
343
+ fused_dropout_add_ln=False,
344
+ residual_in_fp32=False,
345
+ sequence_parallel=False,
346
+ mark_shared_params=False,
347
+ ):
348
+ """
349
+ This Block has a slightly different structure compared to a regular
350
+ prenorm Transformer block.
351
+ The standard block is: LN -> MHA / MLP -> Dropout -> Add.
352
+ [Ref: https://arxiv.org/abs/2002.04745]
353
+ Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
354
+ the hidden_states (output1 of the MHA / MLP) and the residual.
355
+ This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
356
+ The residual needs to be provided (except for the very first block).
357
+ """
358
+ super().__init__()
359
+ self.tied_norm = tied_norm
360
+ self.fused_dropout_add_ln = fused_dropout_add_ln
361
+ self.residual_in_fp32 = residual_in_fp32
362
+ if mixer_cls is None:
363
+ mixer_cls = partial(MHA, num_heads=dim // 64)
364
+ if mlp_cls is None:
365
+ mlp_cls = partial(Mlp, hidden_features=4 * dim)
366
+ self.mixer = mixer_cls(dim)
367
+ self.dropout1 = dropout_cls(resid_dropout1)
368
+ self.norm1 = norm_cls(dim)
369
+ self.mlp = mlp_cls(dim)
370
+ self.dropout2 = dropout_cls(resid_dropout2)
371
+ if not self.tied_norm:
372
+ self.norm2 = norm_cls(dim)
373
+
374
+ if self.fused_dropout_add_ln:
375
+ assert layer_norm_fn is not None, "Triton is not installed"
376
+ assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
377
+ self.dropout1, nn.Dropout
378
+ )
379
+
380
+ # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
381
+ # then the input to each worker in the tensor parallel group will be different.
382
+ # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
383
+ # For now this is not an issue because we always use sequence_parallel=True during training
384
+ # and only use sequence_parallel=False during inference.
385
+
386
+ # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
387
+ if sequence_parallel:
388
+ for p in self.norm1.parameters():
389
+ p._sequence_parallel = True
390
+ if hasattr(self, "norm2"):
391
+ for p in self.norm2.parameters():
392
+ p._sequence_parallel = True
393
+ # Mark the norm parameters as "shared_params" so that we sync their values at init.
394
+ if mark_shared_params:
395
+ for p in self.norm1.parameters():
396
+ p._shared_params = True
397
+ if hasattr(self, "norm2"):
398
+ for p in self.norm2.parameters():
399
+ p._shared_params = True
400
+
401
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
402
+ return self.mixer.allocate_inference_cache(
403
+ batch_size, max_seqlen, dtype=dtype, **kwargs
404
+ )
405
+
406
+ def forward(
407
+ self,
408
+ hidden_states1: Tensor,
409
+ hidden_states2: Optional[Tensor] = None,
410
+ residual: Optional[Tensor] = None,
411
+ mixer_kwargs=None,
412
+ ):
413
+ r"""Pass the input through the encoder layer.
414
+
415
+ Args:
416
+ hidden_states1: the output of the previous attention (mixer) or embedding layer.
417
+ hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
418
+ residual: the residual stream carried between blocks (None for the very first block).
419
+ """
420
+ # TODO: Ideally we should only do the allgather / allreduce once for
421
+ # the Linear to MLP & Attention
422
+ if not self.fused_dropout_add_ln:
423
+ dropped1 = self.dropout1(hidden_states1)
424
+ # For the very 1st block, we only want 1 dropout, not two different dropouts
425
+ if hidden_states2 is not None:
426
+ dropped2 = self.dropout2(hidden_states2)
427
+ residual = (
428
+ (residual + dropped1 + dropped2)
429
+ if residual is not None
430
+ else dropped1 + dropped2
431
+ )
432
+ else:
433
+ residual = (residual + dropped1) if residual is not None else dropped1
434
+ hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
435
+ hidden_states2 = (
436
+ self.norm2(residual.to(dtype=self.norm2.weight.dtype))
437
+ if not self.tied_norm
438
+ else hidden_states1
439
+ )
440
+ if self.residual_in_fp32:
441
+ residual = residual.to(torch.float32)
442
+ else:
443
+ weight2, bias2 = (
444
+ (self.norm2.weight, self.norm2.bias)
445
+ if not self.tied_norm
446
+ else (None, None)
447
+ )
448
+ hidden_states1, *rest, residual = layer_norm_fn(
449
+ hidden_states1,
450
+ self.norm1.weight,
451
+ self.norm1.bias,
452
+ residual=residual,
453
+ x1=hidden_states2,
454
+ weight1=weight2,
455
+ bias1=bias2,
456
+ eps=self.norm1.eps,
457
+ dropout_p=self.dropout1.p if self.training else 0.0,
458
+ prenorm=True,
459
+ residual_in_fp32=self.residual_in_fp32,
460
+ is_rms_norm=isinstance(self.norm1, RMSNorm),
461
+ )
462
+ if self.tied_norm:
463
+ hidden_states2 = hidden_states1
464
+ else:
465
+ (hidden_states2,) = rest
466
+ if mixer_kwargs is None:
467
+ mixer_kwargs = {}
468
+ hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
469
+ hidden_states2 = self.mlp(hidden_states2)
470
+ return hidden_states1, hidden_states2, residual
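
The block classes above return their output together with a separately carried residual stream. Below is a minimal usage sketch of the prenorm Block, assuming the uploaded modules are collected into a local package (hypothetically reranker_modules/ with an empty __init__.py, since block.py uses relative imports); flash-attn is not required because use_flash_attn=False and fused_dropout_add_ln=False keep everything on the pure-PyTorch path.

from functools import partial

import torch
import torch.nn as nn

# hypothetical package layout: reranker_modules/{__init__.py, block.py, mha.py, mlp.py}
from reranker_modules.block import Block
from reranker_modules.mha import MHA
from reranker_modules.mlp import Mlp

dim, batch, seqlen = 768, 2, 16
mixer_cls = partial(MHA, num_heads=12, use_flash_attn=False)  # pure-PyTorch attention path
mlp_cls = partial(Mlp, hidden_features=4 * dim)

blocks = nn.ModuleList(
    [Block(dim, mixer_cls=mixer_cls, mlp_cls=mlp_cls, prenorm=True) for _ in range(2)]
)

hidden_states = torch.randn(batch, seqlen, dim)
residual = None  # the very first block receives no residual
for block in blocks:
    # each prenorm Block returns (hidden_states, residual); the residual stream is
    # carried outside the block so dropout + add + LayerNorm can later be fused
    hidden_states, residual = block(hidden_states, residual)

print(hidden_states.shape)  # torch.Size([2, 16, 768])

Keeping the residual outside the block is what lets the Triton layer_norm_fn path fuse dropout, add and normalization into a single kernel when fused_dropout_add_ln=True.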
config.json ADDED
@@ -0,0 +1,39 @@
+{
+  "_name_or_path": "jinaai/jina-reranker-v2-base-multilingual",
+  "architectures": ["XLMRobertaForSequenceClassification"],
+  "attention_probs_dropout_prob": 0.1,
+  "auto_map": {
+    "AutoConfig": "configuration_xlm_roberta.XLMRobertaFlashConfig",
+    "AutoModel": "modeling_xlm_roberta.XLMRobertaModel",
+    "AutoModelForSequenceClassification": "modeling_xlm_roberta.XLMRobertaForSequenceClassification"
+  },
+  "bos_token_id": 0,
+  "classifier_dropout": null,
+  "emb_pooler": null,
+  "eos_token_id": 2,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 768,
+  "num_labels": 1,
+  "id2label": {
+    "0": "LABEL_0"
+  },
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "label2id": {
+    "LABEL_0": 0
+  },
+  "layer_norm_eps": 1e-5,
+  "max_position_embeddings": 1026,
+  "num_attention_heads": 12,
+  "num_hidden_layers": 12,
+  "output_past": true,
+  "pad_token_id": 1,
+  "position_embedding_type": "absolute",
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.40.0",
+  "type_vocab_size": 1,
+  "use_cache": false,
+  "use_flash_attn": true,
+  "vocab_size": 250002
+}
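
Because auto_map points at the custom configuration_xlm_roberta / modeling_xlm_roberta classes, this checkpoint has to be loaded with trust_remote_code=True. A hedged sketch of scoring query-document pairs through the standard sequence-classification head follows (num_labels is 1, so the logit is the relevance score). It assumes the full repository, including the tokenizer files and the modeling code not shown in this commit view, is available; note that "use_flash_attn": true may require a CUDA device with flash-attn installed.

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "jinaai/jina-reranker-v2-base-multilingual"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(
    model_id, torch_dtype="auto", trust_remote_code=True
)
model.eval()

pairs = [
    ("how do rainbows form?", "Rainbows appear when sunlight is refracted by raindrops."),
    ("how do rainbows form?", "The 2008 financial crisis began in the housing market."),
]
inputs = tokenizer(
    [q for q, _ in pairs],
    [d for _, d in pairs],
    padding=True,
    truncation=True,
    max_length=1024,
    return_tensors="pt",
)
with torch.no_grad():
    # num_labels is 1, so each pair yields a single relevance logit
    scores = model(**inputs).logits.squeeze(-1)
print(scores)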
configuration_xlm_roberta.py ADDED
@@ -0,0 +1,69 @@
+from transformers import PretrainedConfig
+import torch
+
+class XLMRobertaFlashConfig(PretrainedConfig):
+    def __init__(
+        self,
+        vocab_size=30522,
+        hidden_size=768,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        intermediate_size=3072,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.1,
+        attention_probs_dropout_prob=0.1,
+        max_position_embeddings=512,
+        type_vocab_size=2,
+        initializer_range=0.02,
+        layer_norm_eps=1e-12,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
+        position_embedding_type="absolute",
+        use_cache=True,
+        classifier_dropout=None,
+        lora_adaptations=None,
+        lora_rank=4,
+        lora_dropout_p=0.0,
+        lora_alpha=1,
+        lora_main_params_trainable=False,
+        load_trained_adapters=False,
+        use_flash_attn=True,
+        torch_dtype=None,
+        emb_pooler=None,
+        matryoshka_dimensions=None,
+        truncate_dim=None,
+        **kwargs,
+    ):
+        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.hidden_act = hidden_act
+        self.intermediate_size = intermediate_size
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.max_position_embeddings = max_position_embeddings
+        self.type_vocab_size = type_vocab_size
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.position_embedding_type = position_embedding_type
+        self.use_cache = use_cache
+        self.classifier_dropout = classifier_dropout
+        self.load_trained_adapters = load_trained_adapters
+        self.lora_adaptations = lora_adaptations
+        self.lora_rank = lora_rank
+        self.lora_dropout_p = lora_dropout_p
+        self.lora_alpha = lora_alpha
+        self.lora_main_params_trainable = lora_main_params_trainable
+        self.use_flash_attn = use_flash_attn
+        self.emb_pooler = emb_pooler
+        self.matryoshka_dimensions = matryoshka_dimensions
+        self.truncate_dim = truncate_dim
+        if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
+            self.torch_dtype = getattr(torch, torch_dtype)
+        else:
+            self.torch_dtype = torch_dtype
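
A small sketch of the torch_dtype handling at the end of __init__ above: a dtype given as a string such as "bfloat16" is resolved to the matching torch.dtype, anything else is stored unchanged. It assumes configuration_xlm_roberta.py is importable from the working directory.

import torch
from configuration_xlm_roberta import XLMRobertaFlashConfig

cfg = XLMRobertaFlashConfig(hidden_size=768, num_attention_heads=12, torch_dtype="bfloat16")
print(cfg.torch_dtype is torch.bfloat16)  # True: the string was resolved to a torch.dtype

cfg_plain = XLMRobertaFlashConfig(torch_dtype=None)
print(cfg_plain.torch_dtype)  # None is stored unchanged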
embedding.py ADDED
@@ -0,0 +1,62 @@
+# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/embedding.py
+# Commit id: f1a73d074002226c42ce65a1df170ecff9f022c0
+
+# Copyright (c) 2022, Tri Dao.
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from torch import Tensor
+
+from transformers.models.xlm_roberta.modeling_xlm_roberta import create_position_ids_from_input_ids
+
+
+class XLMRobertaEmbeddings(nn.Module):
+    def __init__(
+        self,
+        embed_dim,
+        vocab_size,
+        max_position_embeddings,
+        type_vocab_size,
+        padding_idx=None,
+        device=None,
+        dtype=None,
+    ):
+        """
+        If max_position_embeddings <= 0, there's no position embeddings
+        If type_vocab_size <= 0, there's no token type embeddings
+        """
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        self.word_embeddings = nn.Embedding(
+            vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
+        )
+        self.max_position_embeddings = max_position_embeddings
+        self.type_vocab_size = type_vocab_size
+        if self.max_position_embeddings > 0:
+            self.position_embeddings = nn.Embedding(
+                max_position_embeddings, embed_dim, **factory_kwargs
+            )
+        if self.type_vocab_size > 0:
+            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
+
+    def forward(self, input_ids, position_ids=None, token_type_ids=None):
+        """
+        input_ids: (batch, seqlen)
+        position_ids: (batch, seqlen)
+        token_type_ids: (batch, seqlen)
+        """
+        batch_size, seqlen = input_ids.shape
+        embeddings = self.word_embeddings(input_ids)
+        if self.max_position_embeddings > 0:
+            if position_ids is None:
+                position_ids = create_position_ids_from_input_ids(
+                    input_ids, padding_idx=self.word_embeddings.padding_idx
+                ).to(input_ids.device)
+                # position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings = embeddings + position_embeddings
+        if self.type_vocab_size > 0:
+            if token_type_ids is None:
+                token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
+            token_type_embeddings = self.token_type_embeddings(token_type_ids)
+            embeddings = embeddings + token_type_embeddings
+        return embeddings
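
The position ids above come from the transformers helper, which numbers real tokens starting at padding_idx + 1 and leaves padding positions at padding_idx. An illustrative check follows (the token ids are made up; 0, 2 and 1 are XLM-R's <s>, </s> and <pad>); this offset is consistent with the config's max_position_embeddings of 1026 covering 1024 usable positions plus the shift.

import torch
from transformers.models.xlm_roberta.modeling_xlm_roberta import (
    create_position_ids_from_input_ids,
)

pad_id = 1  # XLM-RoBERTa padding_idx
input_ids = torch.tensor([[0, 2057, 431, 2, 1, 1]])  # <s> ... </s> <pad> <pad>
position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=pad_id)
print(position_ids)  # tensor([[2, 3, 4, 5, 1, 1]])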
mha.py ADDED
@@ -0,0 +1,662 @@
1
+ # Copyright (c) 2023, Tri Dao.
2
+ # Adapted from https://github.com/Dao-AILab/flash-attention/pull/556
3
+
4
+ import math
5
+ from functools import partial
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from einops import rearrange, repeat
10
+
11
+ try:
12
+ from flash_attn import (
13
+ flash_attn_kvpacked_func,
14
+ flash_attn_qkvpacked_func,
15
+ flash_attn_varlen_kvpacked_func,
16
+ flash_attn_varlen_qkvpacked_func,
17
+ flash_attn_with_kvcache,
18
+ )
19
+ except ImportError:
20
+ flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
21
+ flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
22
+ flash_attn_with_kvcache = None
23
+
24
+ try:
25
+ from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
26
+ except ImportError:
27
+ FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
28
+
29
+
30
+ class FlashSelfAttention(nn.Module):
31
+ """Implement the scaled dot product attention with softmax.
32
+ Arguments
33
+ ---------
34
+ softmax_scale: The temperature to use for the softmax attention.
35
+ (default: 1/sqrt(d_keys) where d_keys is computed at
36
+ runtime)
37
+ attention_dropout: The dropout rate to apply to the attention
38
+ (default: 0.0)
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ causal=False,
44
+ softmax_scale=None,
45
+ attention_dropout=0.0,
46
+ window_size=(-1, -1),
47
+ deterministic=False,
48
+ ):
49
+ super().__init__()
50
+ assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
51
+ assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
52
+ self.causal = causal
53
+ self.softmax_scale = softmax_scale
54
+ self.drop = nn.Dropout(attention_dropout)
55
+ self.window_size = window_size
56
+ self.deterministic = deterministic
57
+
58
+ def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
59
+ """Implements the multihead softmax attention.
60
+ Arguments
61
+ ---------
62
+ qkv: The tensor containing the query, key, and value.
63
+ If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
64
+ If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
65
+ (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
66
+ causal: if passed, will override self.causal
67
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
68
+ of the sequences in the batch, used to index into qkv.
69
+ max_seqlen: int. Maximum sequence length in the batch.
70
+ Returns:
71
+ --------
72
+ out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
73
+ else (B, S, H, D).
74
+ """
75
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
76
+ assert qkv.is_cuda
77
+ causal = self.causal if causal is None else causal
78
+ unpadded = cu_seqlens is not None
79
+
80
+ if unpadded:
81
+ assert cu_seqlens.dtype == torch.int32
82
+ assert max_seqlen is not None
83
+ assert isinstance(max_seqlen, int)
84
+ return flash_attn_varlen_qkvpacked_func(
85
+ qkv,
86
+ cu_seqlens,
87
+ max_seqlen,
88
+ self.drop.p if self.training else 0.0,
89
+ softmax_scale=self.softmax_scale,
90
+ causal=causal,
91
+ alibi_slopes=None,
92
+ window_size=self.window_size,
93
+ deterministic=self.deterministic,
94
+ )
95
+ else:
96
+ return flash_attn_qkvpacked_func(
97
+ qkv,
98
+ self.drop.p if self.training else 0.0,
99
+ softmax_scale=self.softmax_scale,
100
+ causal=causal,
101
+ alibi_slopes=None,
102
+ window_size=self.window_size,
103
+ deterministic=self.deterministic,
104
+ )
105
+
106
+
107
+ class FlashCrossAttention(nn.Module):
108
+ """Implement the scaled dot product attention with softmax.
109
+ Arguments
110
+ ---------
111
+ softmax_scale: The temperature to use for the softmax attention.
112
+ (default: 1/sqrt(d_keys) where d_keys is computed at
113
+ runtime)
114
+ attention_dropout: The dropout rate to apply to the attention
115
+ (default: 0.0)
116
+ """
117
+
118
+ def __init__(
119
+ self,
120
+ causal=False,
121
+ softmax_scale=None,
122
+ attention_dropout=0.0,
123
+ window_size=(-1, -1),
124
+ deterministic=False,
125
+ ):
126
+ super().__init__()
127
+ assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
128
+ assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
129
+ self.causal = causal
130
+ self.softmax_scale = softmax_scale
131
+ self.drop = nn.Dropout(attention_dropout)
132
+ self.window_size = window_size
133
+ self.deterministic = deterministic
134
+
135
+ def forward(
136
+ self,
137
+ q,
138
+ kv,
139
+ causal=None,
140
+ cu_seqlens=None,
141
+ max_seqlen=None,
142
+ cu_seqlens_k=None,
143
+ max_seqlen_k=None,
144
+ ):
145
+ """Implements the multihead softmax attention.
146
+ Arguments
147
+ ---------
148
+ q: The tensor containing the query. (B, Sq, H, D)
149
+ kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
150
+ causal: if passed, will override self.causal
151
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
152
+ of the sequences in the batch, used to index into q.
153
+ max_seqlen: int. Maximum sequence length in the batch of q.
154
+ cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
155
+ of the sequences in the batch, used to index into kv.
156
+ max_seqlen_k: int. Maximum sequence length in the batch of k and v.
157
+ """
158
+ assert q.dtype in [torch.float16, torch.bfloat16]
159
+ assert q.is_cuda and kv.is_cuda
160
+ causal = self.causal if causal is None else causal
161
+ unpadded = cu_seqlens is not None
162
+
163
+ if unpadded:
164
+ assert cu_seqlens.dtype == torch.int32
165
+ assert max_seqlen is not None
166
+ assert isinstance(max_seqlen, int)
167
+ assert cu_seqlens_k is not None
168
+ assert cu_seqlens_k.dtype == torch.int32
169
+ assert max_seqlen_k is not None
170
+ assert isinstance(max_seqlen_k, int)
171
+ return flash_attn_varlen_kvpacked_func(
172
+ q,
173
+ kv,
174
+ cu_seqlens,
175
+ cu_seqlens_k,
176
+ max_seqlen,
177
+ max_seqlen_k,
178
+ self.drop.p if self.training else 0.0,
179
+ softmax_scale=self.softmax_scale,
180
+ causal=causal,
181
+ alibi_slopes=None,
182
+ window_size=self.window_size,
183
+ deterministic=self.deterministic,
184
+ )
185
+ else:
186
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
187
+ seqlen_k = kv.shape[1]
188
+ assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
189
+ return flash_attn_kvpacked_func(
190
+ q,
191
+ kv,
192
+ self.drop.p if self.training else 0.0,
193
+ causal=causal,
194
+ softmax_scale=self.softmax_scale,
195
+ alibi_slopes=None,
196
+ window_size=self.window_size,
197
+ deterministic=self.deterministic,
198
+ )
199
+
200
+
201
+ class SelfAttention(nn.Module):
202
+ """Implement the scaled dot product attention with softmax.
203
+ Arguments
204
+ ---------
205
+ softmax_scale: The temperature to use for the softmax attention.
206
+ (default: 1/sqrt(d_keys) where d_keys is computed at
207
+ runtime)
208
+ attention_dropout: The dropout rate to apply to the attention
209
+ (default: 0.0)
210
+ """
211
+
212
+ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
213
+ super().__init__()
214
+ self.causal = causal
215
+ self.softmax_scale = softmax_scale
216
+ self.drop = nn.Dropout(attention_dropout)
217
+
218
+ def forward(self, qkv, causal=None, key_padding_mask=None):
219
+ """Implements the multihead softmax attention.
220
+ Arguments
221
+ ---------
222
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
223
+ causal: if passed, will override self.causal
224
+ key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
225
+ False means to mask out. (B, S)
226
+ """
227
+ batch_size, seqlen = qkv.shape[0], qkv.shape[1]
228
+ causal = self.causal if causal is None else causal
229
+ q, k, v = qkv.unbind(dim=2)
230
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
231
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
232
+ if key_padding_mask is not None:
233
+ padding_mask = torch.full(
234
+ (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
235
+ )
236
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
237
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
238
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
239
+ if causal:
240
+ # "triu_tril_cuda_template" not implemented for 'BFloat16'
241
+ # So we have to construct the mask in float
242
+ causal_mask = torch.triu(
243
+ torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
244
+ )
245
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
246
+ scores = scores + causal_mask.to(dtype=scores.dtype)
247
+ attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
248
+ attention_drop = self.drop(attention)
249
+ output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
250
+ return output
251
+
252
+
253
+ class CrossAttention(nn.Module):
254
+ """Implement the scaled dot product attention with softmax.
255
+ Arguments
256
+ ---------
257
+ softmax_scale: The temperature to use for the softmax attention.
258
+ (default: 1/sqrt(d_keys) where d_keys is computed at
259
+ runtime)
260
+ attention_dropout: The dropout rate to apply to the attention
261
+ (default: 0.0)
262
+ """
263
+
264
+ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
265
+ super().__init__()
266
+ self.causal = causal
267
+ self.softmax_scale = softmax_scale
268
+ self.drop = nn.Dropout(attention_dropout)
269
+
270
+ def forward(self, q, kv, causal=None, key_padding_mask=None):
271
+ """Implements the multihead softmax attention.
272
+ Arguments
273
+ ---------
274
+ q: The tensor containing the query. (B, Sq, H, D)
275
+ kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
276
+ causal: if passed, will override self.causal
277
+ key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
278
+ False means to mask out. (B, Sk)
279
+ """
280
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
281
+ causal = self.causal if causal is None else causal
282
+ seqlen_k = kv.shape[1]
283
+ assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
284
+ if kv.shape[3] != q.shape[2]: # MQA/GQA
285
+ kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
286
+ k, v = kv.unbind(dim=2)
287
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
288
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
289
+ if key_padding_mask is not None:
290
+ padding_mask = torch.full(
291
+ (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
292
+ )
293
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
294
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
295
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
296
+ if causal:
297
+ # causal mask needs to take into account the difference between seqlen_q and seqlen_k
298
+ row_idx = rearrange(
299
+ torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
300
+ )
301
+ col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
302
+ sk = (
303
+ seqlen_k
304
+ if key_padding_mask is None
305
+ else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
306
+ )
307
+ causal_mask = col_idx > row_idx + sk - seqlen_q
308
+ scores = scores.masked_fill(causal_mask, -10000.0)
309
+ attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
310
+ attention_drop = self.drop(attention)
311
+ output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
312
+ return output
313
+
314
+
315
+ class LinearResidual(nn.Linear):
316
+ """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""
317
+
318
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
319
+ return super().forward(input), input
320
+
321
+
322
+ def _update_kv_cache(kv, inference_params, layer_idx):
323
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
324
+ # Pre-allocate memory for key-values for inference.
325
+ num_heads, head_dim = kv.shape[-2:]
326
+ if layer_idx not in inference_params.key_value_memory_dict:
327
+ kv_cache = torch.empty(
328
+ inference_params.max_batch_size,
329
+ inference_params.max_seqlen,
330
+ 2,
331
+ num_heads,
332
+ head_dim,
333
+ dtype=kv.dtype,
334
+ device=kv.device,
335
+ )
336
+ inference_params.key_value_memory_dict[layer_idx] = kv_cache
337
+ else:
338
+ kv_cache = inference_params.key_value_memory_dict[layer_idx]
339
+ # Adjust key and value for inference
340
+ batch_start = inference_params.batch_size_offset
341
+ batch_end = batch_start + kv.shape[0]
342
+ sequence_start = inference_params.seqlen_offset
343
+ sequence_end = sequence_start + kv.shape[1]
344
+ assert batch_end <= kv_cache.shape[0]
345
+ assert sequence_end <= kv_cache.shape[1]
346
+ assert kv_cache is not None
347
+ kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
348
+ return kv_cache[batch_start:batch_end, :sequence_end, ...]
349
+
350
+
351
+ class MHA(nn.Module):
352
+ """Multi-head self-attention and cross-attention"""
353
+
354
+ def __init__(
355
+ self,
356
+ embed_dim,
357
+ num_heads,
358
+ num_heads_kv=None,
359
+ cross_attn=False,
360
+ qkv_proj_bias=True,
361
+ out_proj_bias=True,
362
+ dropout=0.0,
363
+ softmax_scale=None,
364
+ causal=False,
365
+ layer_idx=None,
366
+ dwconv=False,
367
+ window_size=(-1, -1),
368
+ fused_bias_fc=False,
369
+ use_flash_attn=False,
370
+ return_residual=False,
371
+ checkpointing=False,
372
+ device=None,
373
+ dtype=None,
374
+ ) -> None:
375
+ """
376
+ num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
377
+ return_residual: whether to return the input x along with the output. This is for
378
+ performance reasons: for post-norm architecture, returning the input allows us
379
+ to fuse the backward of nn.Linear with the residual connection.
380
+ """
381
+ factory_kwargs = {"device": device, "dtype": dtype}
382
+ super().__init__()
383
+ self.embed_dim = embed_dim
384
+ self.cross_attn = cross_attn
385
+ self.causal = causal
386
+ self.layer_idx = layer_idx
387
+ self.dwconv = dwconv
388
+ self.use_flash_attn = use_flash_attn
389
+ self.return_residual = return_residual
390
+ self.checkpointing = checkpointing
391
+
392
+ if window_size != (-1, -1):
393
+ assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
394
+
395
+ self.num_heads = num_heads
396
+ self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
397
+ assert (
398
+ self.num_heads % self.num_heads_kv == 0
399
+ ), "num_heads must be divisible by num_heads_kv"
400
+ assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
401
+ self.head_dim = self.embed_dim // num_heads
402
+ qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
403
+ kv_dim = 2 * self.head_dim * self.num_heads_kv
404
+
405
+ if fused_bias_fc and FusedDense is None:
406
+ raise ImportError("fused_dense is not installed")
407
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
408
+ linear_resid_cls = (
409
+ LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
410
+ )
411
+ wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
412
+ inner_attn_cls = (
413
+ partial(FlashSelfAttention, window_size=window_size)
414
+ if use_flash_attn
415
+ else SelfAttention
416
+ )
417
+ inner_cross_attn_cls = (
418
+ partial(FlashCrossAttention, window_size=window_size)
419
+ if use_flash_attn
420
+ else CrossAttention
421
+ )
422
+ if not self.cross_attn:
423
+ self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
424
+ else:
425
+ self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
426
+ self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
427
+ if self.dwconv:
428
+ if self.num_heads_kv == self.num_heads:
429
+ self.dwconv_qkv = nn.Conv1d(
430
+ qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
431
+ )
432
+ else:
433
+ self.dwconv_q = nn.Conv1d(
434
+ embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
435
+ )
436
+ self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
437
+ self.inner_attn = inner_attn_cls(
438
+ causal=causal,
439
+ softmax_scale=softmax_scale,
440
+ attention_dropout=dropout,
441
+ )
442
+ self.inner_cross_attn = inner_cross_attn_cls(
443
+ causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
444
+ )
445
+ self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
446
+
447
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
448
+ dtype = self.out_proj.weight.dtype if dtype is None else dtype
449
+ device = self.out_proj.weight.device
450
+ return torch.empty(
451
+ batch_size,
452
+ max_seqlen,
453
+ 2,
454
+ self.num_heads_kv,
455
+ self.head_dim,
456
+ dtype=dtype,
457
+ device=device,
458
+ )
459
+
460
+ def _update_kv_cache(self, kv, inference_params):
461
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
462
+ assert not self.dwconv, "Generation does not support dwconv yet"
463
+ assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
464
+ return _update_kv_cache(kv, inference_params, self.layer_idx)
465
+
466
+ def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
467
+ """
468
+ Fast path that combines 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
469
+ q: (batch_size, seqlen_q, nheads, head_dim)
470
+ kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
471
+ """
472
+ assert inference_params is not None and inference_params.seqlen_offset > 0
473
+ assert self.use_flash_attn
474
+ batch = q.shape[0]
475
+ kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
476
+ cache_seqlens = (
477
+ inference_params.lengths_per_sample[:batch]
478
+ if inference_params.lengths_per_sample is not None
479
+ else inference_params.seqlen_offset
480
+ )
481
+ context = flash_attn_with_kvcache(
482
+ q,
483
+ kv_cache[:, :, 0],
484
+ kv_cache[:, :, 1],
485
+ kv[:, :, 0],
486
+ kv[:, :, 1],
487
+ cache_seqlens=cache_seqlens,
488
+ softmax_scale=self.inner_cross_attn.softmax_scale,
489
+ causal=self.inner_cross_attn.causal,
490
+ rotary_interleaved=False,
491
+ alibi_slopes=None,
492
+ )
493
+ return context
494
+
495
+ def _update_kvcache_attention(self, q, kv, inference_params):
496
+ """Write kv to inference_params, then do attention"""
497
+ if (
498
+ inference_params.seqlen_offset == 0
499
+ or flash_attn_with_kvcache is None
500
+ or not self.use_flash_attn
501
+ ):
502
+ # TODO: this only uses seqlen_offset and not lengths_per_sample.
503
+ kv = self._update_kv_cache(kv, inference_params)
504
+ return self.inner_cross_attn(q, kv)
505
+ else:
506
+ batch = q.shape[0]
507
+ kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
508
+ cache_seqlens = (
509
+ inference_params.lengths_per_sample[:batch]
510
+ if inference_params.lengths_per_sample is not None
511
+ else inference_params.seqlen_offset
512
+ )
513
+ return flash_attn_with_kvcache(
514
+ q,
515
+ kv_cache[:, :, 0],
516
+ kv_cache[:, :, 1],
517
+ kv[:, :, 0],
518
+ kv[:, :, 1],
519
+ cache_seqlens=cache_seqlens,
520
+ softmax_scale=self.inner_cross_attn.softmax_scale,
521
+ causal=self.inner_cross_attn.causal,
522
+ alibi_slopes=None,
523
+ )
524
+
525
+ def forward(
526
+ self,
527
+ x,
528
+ x_kv=None,
529
+ key_padding_mask=None,
530
+ cu_seqlens=None,
531
+ max_seqlen=None,
532
+ mixer_subset=None,
533
+ inference_params=None,
534
+ **kwargs,
535
+ ):
536
+ """
537
+ Arguments:
538
+ x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
539
+ cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
540
+ is the sum of the sequence lengths in the batch.
541
+ x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
542
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
543
+ of the sequences in the batch, used to index into x. Only applicable when using
544
+ FlashAttention.
545
+ max_seqlen: int. Maximum sequence length in the batch.
546
+ key_padding_mask: boolean mask, True means to keep, False means to mask out.
547
+ (batch, seqlen). Only applicable when not using FlashAttention.
548
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
549
+ before applying the query projection. Useful for e.g., ViT where we only care
550
+ about the CLS token in the last layer.
551
+ inference_params: for generation. Adapted from Megatron-LM (and Apex)
552
+ https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
553
+ """
554
+ if cu_seqlens is not None:
555
+ assert max_seqlen is not None
556
+ assert key_padding_mask is None
557
+ assert self.use_flash_attn
558
+ assert not self.dwconv
559
+ if key_padding_mask is not None:
560
+ assert cu_seqlens is None
561
+ assert max_seqlen is None
562
+ assert not self.use_flash_attn
563
+ if inference_params is not None:
564
+ assert key_padding_mask is None
565
+ assert cu_seqlens is None and max_seqlen is None
566
+ assert not self.dwconv
567
+
568
+ kwargs = (
569
+ {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
570
+ if self.use_flash_attn
571
+ else {"key_padding_mask": key_padding_mask, **kwargs}
572
+ )
573
+ seqlen_offset = (
574
+ 0
575
+ if inference_params is None
576
+ else (
577
+ inference_params.lengths_per_sample
578
+ if inference_params.lengths_per_sample is not None
579
+ else inference_params.seqlen_offset
580
+ )
581
+ )
582
+ rotary_max_seqlen = (
583
+ inference_params.max_sequence_len if inference_params is not None else max_seqlen
584
+ )
585
+ batch, seqlen = x.shape[:2]
586
+ if not self.cross_attn and self.num_heads_kv == self.num_heads:
587
+ assert x_kv is None and mixer_subset is None
588
+ if not self.return_residual:
589
+ qkv = self.Wqkv(x)
590
+ else:
591
+ qkv, x = self.Wqkv(x)
592
+ if self.dwconv:
593
+ qkv = rearrange(
594
+ self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
595
+ ).contiguous()
596
+ qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
597
+ if (
598
+ inference_params is None
599
+ or inference_params.seqlen_offset == 0
600
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
601
+ or not self.use_flash_attn
602
+ ):
603
+ if inference_params is None:
604
+ if not self.checkpointing:
605
+ context = self.inner_attn(qkv, **kwargs)
606
+ else:
607
+ context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
608
+ else:
609
+ context = self._update_kvcache_attention(
610
+ qkv[:, :, 0], qkv[:, :, 1:], inference_params
611
+ )
612
+ else:
613
+ context = self._apply_rotary_update_kvcache_attention(
614
+ qkv[:, :, 0], qkv[:, :, 1:], inference_params
615
+ )
616
+ else:
617
+ if self.cross_attn:
618
+ if not self.return_residual:
619
+ q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
620
+ kv = self.Wkv(x_kv if x_kv is not None else x)
621
+ else:
622
+ if x_kv is not None:
623
+ kv, x_kv = self.Wkv(x_kv)
624
+ else:
625
+ kv, x = self.Wkv(x)
626
+ q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
627
+ else:
628
+ assert self.num_heads_kv != self.num_heads
629
+ if not self.return_residual:
630
+ qkv = self.Wqkv(x)
631
+ else:
632
+ qkv, x = self.Wqkv(x)
633
+ q = qkv[..., : self.num_heads * self.head_dim]
634
+ kv = qkv[..., self.num_heads * self.head_dim :]
635
+ q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
636
+ kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
637
+ if self.dwconv:
638
+ q = rearrange(
639
+ self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
640
+ ).contiguous()
641
+ kv = rearrange(
642
+ self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
643
+ ).contiguous()
644
+ if (
645
+ inference_params is None
646
+ or inference_params.seqlen_offset == 0
647
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
648
+ or not self.use_flash_attn
649
+ ):
650
+ if inference_params is None:
651
+ if not self.checkpointing:
652
+ context = self.inner_cross_attn(q, kv, **kwargs)
653
+ else:
654
+ context = torch.utils.checkpoint.checkpoint(
655
+ self.inner_cross_attn, q, kv, **kwargs
656
+ )
657
+ else:
658
+ context = self._update_kvcache_attention(q, kv, inference_params)
659
+ else:
660
+ context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
661
+ out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
662
+ return out if not self.return_residual else (out, x)
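
When use_flash_attn=False, MHA falls back to the pure-PyTorch SelfAttention/CrossAttention classes, so it also runs on CPU; with num_heads_kv < num_heads the grouped-query path is taken and the packed Wqkv projection shrinks accordingly. A sketch, assuming mha.py is on the Python path and einops is installed:

import torch
from mha import MHA

embed_dim, num_heads, num_heads_kv = 768, 12, 4  # 4 kv heads shared by 12 query heads
mha = MHA(embed_dim, num_heads, num_heads_kv=num_heads_kv, use_flash_attn=False)

# Wqkv packs q plus the smaller kv projection: head_dim * (num_heads + 2 * num_heads_kv)
head_dim = embed_dim // num_heads
assert mha.Wqkv.out_features == head_dim * (num_heads + 2 * num_heads_kv)  # 64 * 20 = 1280

x = torch.randn(2, 16, embed_dim)
out = mha(x)  # CrossAttention repeats the kv heads to match the 12 query heads
print(out.shape)  # torch.Size([2, 16, 768])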
mlp.py ADDED
@@ -0,0 +1,194 @@
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mlp.py
2
+ # Commit id: c3b219665292c61a51153d0ded4473c494296382
3
+
4
+ # Copyright (c) 2023, Tri Dao.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.distributed import ProcessGroup
10
+
11
+
12
+ try:
13
+ from flash_attn.ops.activations import swiglu
14
+ except ImportError:
15
+ swiglu = None
16
+
17
+ try:
18
+ from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
19
+ except ImportError:
20
+ ColumnParallelLinear, RowParallelLinear = None, None
21
+
22
+ try:
23
+ from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
24
+ except ImportError:
25
+ FusedMLP, ParallelFusedMLP = None, None
26
+
27
+
28
+ class Mlp(nn.Module):
29
+ def __init__(
30
+ self,
31
+ in_features,
32
+ hidden_features=None,
33
+ out_features=None,
34
+ activation=F.gelu,
35
+ bias1=True,
36
+ bias2=True,
37
+ return_residual=False,
38
+ device=None,
39
+ dtype=None,
40
+ ):
41
+ factory_kwargs = {"device": device, "dtype": dtype}
42
+ super().__init__()
43
+ out_features = out_features if out_features is not None else in_features
44
+ hidden_features = hidden_features if hidden_features is not None else in_features * 4
45
+ self.return_residual = return_residual
46
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
47
+ self.activation = activation
48
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        y = self.activation(y)
        y = self.fc2(y)
        return y if not self.return_residual else (y, x)


class ParallelMLP(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        process_group: ProcessGroup = None,
        sequence_parallel=True,
        bias1=True,
        bias2=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        assert ColumnParallelLinear is not None, "Need to install fused_dense"
        assert RowParallelLinear is not None, "Need to install fused_dense"
        out_features = out_features if out_features is not None else in_features
        hidden_features = hidden_features if hidden_features is not None else in_features * 4
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            process_group,
            bias=bias1,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )
        self.activation = activation
        self.fc2 = RowParallelLinear(
            hidden_features,
            out_features,
            process_group,
            bias=bias2,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

    def forward(self, x):
        y = self.fc1(x)
        y = self.activation(y)
        y = self.fc2(y)
        return y


class GatedMlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.sigmoid,
        bias1=True,
        bias2=True,
        multiple_of=128,
        return_residual=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = (
            hidden_features if hidden_features is not None else int(8 * in_features / 3)
        )
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        if self.activation == F.sigmoid:  # Special case for GLU
            y = F.glu(y, dim=-1)
        elif self.activation == F.silu and swiglu is not None:  # Special case for SwiGLU
            y, gate = y.chunk(2, dim=-1)
            y = swiglu(gate, y)
        else:
            y, gate = y.chunk(2, dim=-1)
            y = y * self.activation(gate)
        y = self.fc2(y)
        return y if not self.return_residual else (y, x)


class ParallelGatedMlp(nn.Module):
    """Parallel GatedMlp"""

    def __init__(
        self,
        in_features,
        process_group,
        hidden_features=None,
        out_features=None,
        activation=F.sigmoid,
        bias1=True,
        bias2=True,
        multiple_of=128,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = (
            hidden_features if hidden_features is not None else int(8 * in_features / 3)
        )
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        self.fc1 = ColumnParallelLinear(
            in_features,
            2 * hidden_features,
            process_group,
            bias=bias1,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )
        self.activation = activation
        self.fc2 = RowParallelLinear(
            hidden_features,
            out_features,
            process_group,
            bias=bias2,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

    def forward(self, x):
        y = self.fc1(x)
        if self.activation == F.sigmoid:  # Special case for GLU
            y = F.glu(y, dim=-1)
        else:
            y, gate = y.chunk(2, dim=-1)
            y = y * self.activation(gate)
        y = self.fc2(y)
        return y
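For reference, the gating computed in `GatedMlp.forward` can be reproduced with plain PyTorch ops. The sketch below is not part of the uploaded files; the sizes are made up, and it follows the non-fused SwiGLU path (value half scaled by `silu` of the gate half) rather than the fused `swiglu` kernel.

```python
# Minimal SwiGLU sketch (assumed sizes), mirroring GatedMlp with activation=F.silu.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
in_features, hidden_features, out_features = 16, 32, 16
fc1 = torch.nn.Linear(in_features, 2 * hidden_features)  # value and gate projections packed together
fc2 = torch.nn.Linear(hidden_features, out_features)

x = torch.randn(4, in_features)            # (batch, in_features)
y, gate = fc1(x).chunk(2, dim=-1)          # split into value and gate halves
y = y * F.silu(gate)                       # SwiGLU: value scaled by silu(gate)
out = fc2(y)
print(out.shape)                           # torch.Size([4, 16])
```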
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ec9522eeeca8dcbf5aeb01e57db7c153140a204d2890fe9f6732b94af0524804
size 556892306
special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
{
  "bos_token": {
    "content": "<s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "cls_token": {
    "content": "<s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "</s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "mask_token": {
    "content": "<mask>",
    "lstrip": true,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<pad>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "sep_token": {
    "content": "</s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<unk>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
stochastic_depth.py ADDED
@@ -0,0 +1,97 @@
# Implementation modified from torchvision:
# https://github.com/pytorch/vision/blob/main/torchvision/ops/stochastic_depth.py
#
# License:
# BSD 3-Clause License
#
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import torch
import torch.fx
from torch import nn, Tensor


def stochastic_depth(
    input: Tensor, p: float, mode: str, training: bool = True
) -> Tensor:
    """
    Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth"
    <https://arxiv.org/abs/1603.09382>`_ used for randomly dropping residual
    branches of residual architectures.

    Args:
        input (Tensor[N, ...]): The input tensor of arbitrary dimensions with the first one
            being its batch, i.e. a batch with ``N`` rows.
        p (float): probability of the input to be zeroed.
        mode (str): ``"batch"`` or ``"row"``.
            ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
            randomly selected rows from the batch.
        training: apply stochastic depth if ``True``. Default: ``True``

    Returns:
        Tensor[N, ...]: The randomly zeroed tensor.
    """
    if p < 0.0 or p > 1.0:
        raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
    if mode not in ["batch", "row"]:
        raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
    if not training or p == 0.0:
        return input

    survival_rate = 1.0 - p
    if mode == "row":
        size = [input.shape[0]] + [1] * (input.ndim - 1)
    else:
        size = [1] * input.ndim
    noise = torch.empty(size, dtype=input.dtype, device=input.device)
    noise = noise.bernoulli_(survival_rate)
    if survival_rate > 0.0:
        noise.div_(survival_rate)
    return input * noise


torch.fx.wrap("stochastic_depth")


class StochasticDepth(nn.Module):
    """
    See :func:`stochastic_depth`.
    """

    def __init__(self, p: float, mode: str) -> None:
        super().__init__()
        self.p = p
        self.mode = mode

    def forward(self, input: Tensor) -> Tensor:
        return stochastic_depth(input, self.p, self.mode, self.training)

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}(p={self.p}, mode={self.mode})"
        return s
tokenizer.json ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b1e0376b0ca081a6b0c18125d251f214835d1165944f9eac39baf8d9cf2b15fe
size 17082832
tokenizer_config.json ADDED
@@ -0,0 +1,58 @@
{
  "added_tokens_decoder": {
    "0": {
      "content": "<s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "<pad>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "</s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "3": {
      "content": "<unk>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "250001": {
      "content": "<mask>",
      "lstrip": true,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "bos_token": "<s>",
  "clean_up_tokenization_spaces": true,
  "cls_token": "<s>",
  "eos_token": "</s>",
  "mask_token": "<mask>",
  "max_length": 768,
  "model_max_length": 1024,
  "pad_token": "<pad>",
  "sep_token": "</s>",
  "stride": 0,
  "tokenizer_class": "XLMRobertaTokenizer",
  "truncation_side": "right",
  "truncation_strategy": "longest_first",
  "unk_token": "<unk>"
}
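With tokenizer.json, tokenizer_config.json, and special_tokens_map.json in place, the tokenizer can be loaded through transformers. The sketch below is illustrative only; "path/to/this/repo" is a placeholder for the actual repo id or local directory, not a name taken from these files.

```python
# Hypothetical loading sketch for the uploaded tokenizer files.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("path/to/this/repo")   # placeholder path
enc = tok("Hello world", return_tensors="pt", padding="max_length", max_length=8)
print(enc["input_ids"])     # <s> ... </s> followed by <pad> ids
print(tok.mask_token_id)    # 250001, per added_tokens_decoder above
```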
xlm_padding.py ADDED
@@ -0,0 +1,218 @@
# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
# Commit id: c94cd09744d20f0ac587a351ff6ff2e8ad11ae1b

# Previously adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py

import torch
import torch.nn.functional as F
from einops import rearrange, repeat


class IndexFirstAxis(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, indices):
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
        second_dim = other_shape.numel()
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        # return input[indices]
        return torch.gather(
            rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
        ).reshape(-1, *other_shape)

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        grad_output = rearrange(grad_output, "b ... -> b (...)")
        grad_input = torch.zeros(
            [ctx.first_axis_dim, grad_output.shape[1]],
            device=grad_output.device,
            dtype=grad_output.dtype,
        )
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        # grad_input[indices] = grad_output
        grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis = IndexFirstAxis.apply


class IndexPutFirstAxis(torch.autograd.Function):
    @staticmethod
    def forward(ctx, values, indices, first_axis_dim):
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        output = torch.zeros(
            first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
        )
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        output[indices] = values
        # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        grad_values = grad_output[indices]
        # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
        return grad_values, None, None


index_put_first_axis = IndexPutFirstAxis.apply


class IndexFirstAxisResidual(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, indices):
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
        second_dim = other_shape.numel()
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        output = input[indices]
        # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
        # memory format to channel_first. In other words, input might not be contiguous.
        # If we don't detach, Pytorch complains about output being a view and is being modified inplace
        return output, input.detach()

    @staticmethod
    def backward(ctx, grad_output, grad_residual):
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        assert grad_residual.shape[1:] == other_shape
        grad_input = grad_residual
        # grad_input[indices] += grad_output
        indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
        indices = indices.expand_as(grad_output)
        grad_input.scatter_add_(0, indices, grad_output)
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis_residual = IndexFirstAxisResidual.apply


def unpad_input(hidden_states, attention_mask):
    """
    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
    # so we write custom forward and backward to make it a bit faster.
    return (
        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )


def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
    """
    Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized
    to mask other short samples. It helps efficient training on variable-length samples (e.g., the
    supervised fine-tuning task in large language models).
    The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).

    For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
    ```
    [
        [2, 3, 0, 0, 0, 0],
        [3, 2, 0, 0, 0, 0],
        [6, 0, 0, 0, 0, 0]
    ]
    ```
    which refers to the 3D attention mask:
    ```
    [
        [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0],
            [0, 0, 1, 1, 0, 0],
            [0, 0, 1, 1, 1, 0],
            [0, 0, 0, 0, 0, 1]
        ],
        [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0],
            [0, 0, 0, 1, 0, 0],
            [0, 0, 0, 1, 1, 0],
            [0, 0, 0, 0, 0, 1]
        ],
        [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0],
            [1, 1, 1, 1, 0, 0],
            [1, 1, 1, 1, 1, 0],
            [1, 1, 1, 1, 1, 1]
        ]
    ]
    ```

    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means
            the length of a concatenated sequence in the b-th batch element, and 0 means none.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
    """
    length = attention_mask_in_length.sum(dim=-1)
    seqlen = attention_mask_in_length.size(-1)
    attention_mask_2d = torch.arange(
        seqlen, device=length.device, dtype=length.dtype
    ).expand(len(length), seqlen) < length.unsqueeze(1)
    real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
    seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
    indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
    # so we write custom forward and backward to make it a bit faster.
    return (
        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )


def pad_input(hidden_states, indices, batch, seqlen):
    """
    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
        batch: int, batch size for the padded sequence.
        seqlen: int, maximum sequence length for the padded sequence.
    Return:
        hidden_states: (batch, seqlen, ...)
    """
    dim = hidden_states.shape[-1]
    # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
    # output[indices] = hidden_states
    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
    return rearrange(output, "(b s) ... -> b s ...", b=batch)
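To make the unpad/pad round trip above concrete, here is a small sketch in plain PyTorch (not part of the uploaded files, toy shapes) that gathers the valid tokens, builds `indices` and `cu_seqlens` the same way `unpad_input` does, and then scatters the tokens back like `pad_input`; the file's custom autograd functions do the same thing more efficiently and with gradients.

```python
# Minimal unpad/pad round-trip sketch with plain indexing.
import torch
import torch.nn.functional as F

batch, seqlen, dim = 2, 4, 3
hidden = torch.arange(batch * seqlen * dim, dtype=torch.float32).reshape(batch, seqlen, dim)
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])                      # per-token validity

seqlens = mask.sum(dim=-1, dtype=torch.int32)            # tensor([3, 2])
indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 3, 5])

unpadded = hidden.reshape(batch * seqlen, dim)[indices]  # (total_nnz, dim), here (5, 3)

repadded = torch.zeros(batch * seqlen, dim)
repadded[indices] = unpadded                             # scatter tokens back to padded slots
repadded = repadded.reshape(batch, seqlen, dim)
assert torch.equal(repadded * mask.unsqueeze(-1), repadded)  # padded positions stay zero
```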