pt-sk committed
Commit 9add2bc
1 Parent(s): aee6996

Delete model.py

Files changed (1)
model.py +0 -375
model.py DELETED
@@ -1,375 +0,0 @@
"""Simple, minimal implementation of Mamba in one file of PyTorch.

Suggest reading the following before/while reading the code:
    [1] Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Albert Gu and Tri Dao)
        https://arxiv.org/abs/2312.00752
    [2] The Annotated S4 (Sasha Rush and Sidd Karamcheti)
        https://srush.github.io/annotated-s4

Glossary:
    b: batch size                       (`B` in Mamba paper [1] Algorithm 2)
    l: sequence length                  (`L` in [1] Algorithm 2)
    d or d_model: hidden dim
    n or d_state: latent state dim      (`N` in [1] Algorithm 2)
    expand: expansion factor            (`E` in [1] Section 3.4)
    d_in or d_inner: d * expand         (`D` in [1] Algorithm 2)
    A, B, C, D: state space parameters  (See any state space representation formula)
        (B, C are input-dependent (aka selective, a key innovation in Mamba); A, D are not)
    Δ or delta: input-dependent step size
    dt_rank: rank of Δ                  (See [1] Section 3.6 "Parameterization of ∆")

"""
from __future__ import annotations
import math
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Union
from einops import rearrange, repeat, einsum
from parallel_scan import pscan


@dataclass
class ModelArgs:
    d_model: int
    n_layer: int
    vocab_size: int
    d_state: int = 16
    expand: int = 2
    dt_rank: Union[int, str] = 'auto'
    d_conv: int = 4
    pad_vocab_size_multiple: int = 8
    conv_bias: bool = True
    bias: bool = False

    def __post_init__(self):
        self.d_inner = int(self.expand * self.d_model)

        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)

        if self.vocab_size % self.pad_vocab_size_multiple != 0:
            self.vocab_size += (self.pad_vocab_size_multiple
                                - self.vocab_size % self.pad_vocab_size_multiple)

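
# Illustrative sketch (not part of the original model.py): the derived fields that
# __post_init__ computes for a hypothetical 130M-scale configuration. The helper
# name `_example_model_args` and the config values are ours, added only for demonstration.
def _example_model_args():
    args = ModelArgs(d_model=768, n_layer=24, vocab_size=50277)
    assert args.d_inner == 1536      # expand * d_model
    assert args.dt_rank == 48        # ceil(d_model / 16)
    assert args.vocab_size == 50280  # padded up to a multiple of pad_vocab_size_multiple
    return args
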

class Mamba(nn.Module):
    def __init__(self, args: ModelArgs):
        """Full Mamba model."""
        super().__init__()
        self.args = args

        self.embedding = nn.Embedding(args.vocab_size, args.d_model)
        self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])
        self.norm_f = RMSNorm(args.d_model)

        self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight  # Tie output projection to embedding weights.
                                                     # See "Weight Tying" paper


    def forward(self, input_ids):
        """
        Args:
            input_ids (long tensor): shape (b, l)    (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            logits: shape (b, l, vocab_size)

        Official Implementation:
            class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L173

        """
        x = self.embedding(input_ids)

        for layer in self.layers:
            x = layer(x)

        x = self.norm_f(x)
        logits = self.lm_head(x)

        return logits


    @staticmethod
    def from_config(pretrained_model_name: str):
        from transformers.utils import CONFIG_NAME
        from transformers.utils.hub import cached_file

        def load_config_hf(model_name):
            resolved_archive_file = cached_file(model_name, CONFIG_NAME,
                                                _raise_exceptions_for_missing_entries=False)
            return json.load(open(resolved_archive_file))

        config_data = load_config_hf(pretrained_model_name)
        args = ModelArgs(
            d_model=config_data['d_model'],
            n_layer=config_data['n_layer'],
            vocab_size=config_data['vocab_size']
        )
        model = Mamba(args)
        return model


    @staticmethod
    def from_pretrained(pretrained_model_name: str):
        """Load pretrained weights from HuggingFace into model.

        Args:
            pretrained_model_name: One of
                * 'state-spaces/mamba-2.8b-slimpj'
                * 'state-spaces/mamba-2.8b'
                * 'state-spaces/mamba-1.4b'
                * 'state-spaces/mamba-790m'
                * 'state-spaces/mamba-370m'
                * 'state-spaces/mamba-130m'

        Returns:
            model: Mamba model with weights loaded

        """
        from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
        from transformers.utils.hub import cached_file

        def load_config_hf(model_name):
            resolved_archive_file = cached_file(model_name, CONFIG_NAME,
                                                _raise_exceptions_for_missing_entries=False)
            return json.load(open(resolved_archive_file))


        def load_state_dict_hf(model_name, device=None, dtype=None):
            resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
                                                _raise_exceptions_for_missing_entries=False)
            return torch.load(resolved_archive_file, weights_only=True, map_location='cpu', mmap=True)

        config_data = load_config_hf(pretrained_model_name)
        args = ModelArgs(
            d_model=config_data['d_model'],
            n_layer=config_data['n_layer'],
            vocab_size=config_data['vocab_size']
        )
        model = Mamba(args)

        state_dict = load_state_dict_hf(pretrained_model_name)
        new_state_dict = {}
        for key in state_dict:
            new_key = key.replace('backbone.', '')
            new_state_dict[new_key] = state_dict[key]
        model.load_state_dict(new_state_dict)

        return model

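
# Minimal usage sketch (not part of the original model.py): run a random batch of
# token ids through a small untrained model and check the logits shape. Loading
# pretrained weights would instead go through Mamba.from_pretrained(...), e.g.
# Mamba.from_pretrained('state-spaces/mamba-130m'). The helper name
# `_example_forward` and the config values are ours, chosen only for illustration.
def _example_forward():
    args = ModelArgs(d_model=64, n_layer=2, vocab_size=100)  # vocab_size gets padded to 104
    model = Mamba(args)
    input_ids = torch.randint(0, args.vocab_size, (2, 16))   # (b, l)
    logits = model(input_ids)                                # (b, l, vocab_size)
    assert logits.shape == (2, 16, args.vocab_size)
    return logits
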

class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """Simple block wrapping Mamba block with normalization and residual connection."""
        super().__init__()
        self.args = args
        self.mixer = MambaBlock(args)
        self.norm = RMSNorm(args.d_model)


    def forward(self, x):
        """
        Args:
            x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            output: shape (b, l, d)

        Official Implementation:
            Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297

            Note: the official repo chains residual blocks that look like
                [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
            where the first Add is a no-op. This is purely for performance reasons as this
            allows them to fuse the Add->Norm.

            We instead implement our blocks as the more familiar, simpler, and numerically equivalent
                [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ...
            (see the unrolled sketch after this class).

        """
        output = self.mixer(self.norm(x)) + x

        return output

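
# Unrolled sketch of the equivalence claimed in ResidualBlock.forward's docstring
# (added for illustration, not part of the original model.py). With x0 the embedding
# output, the official [Add -> Norm -> Mamba] chain (whose first Add is a no-op) and
# the [Norm -> Mamba -> Add] chain used here both build the same residual stream:
#     x1 = x0 + mixer_1(norm_1(x0))
#     x2 = x1 + mixer_2(norm_2(x1))
#     ...
# The two orderings differ only in how the operations are grouped per block, which is
# what lets the official implementation fuse the Add with the following Norm.
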

class MambaBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1]."""
        super().__init__()
        self.args = args

        self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)

        self.conv1d = nn.Conv1d(
            in_channels=args.d_inner,
            out_channels=args.d_inner,
            bias=args.conv_bias,
            kernel_size=args.d_conv,
            groups=args.d_inner,
            padding=args.d_conv - 1,
        )

        # x_proj takes in `x` and outputs the input-specific Δ, B, C
        self.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False)

        # dt_proj projects Δ from dt_rank to d_in
        self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True)

        A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(args.d_inner))
        self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)


    def forward(self, x):
        """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1].

        Args:
            x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            output: shape (b, l, d)

        Official Implementation:
            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311

        """
        (b, l, d) = x.shape

        x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)
        (x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)

        x = rearrange(x, 'b l d_in -> b d_in l')
        x = self.conv1d(x)[:, :, :l]
        x = rearrange(x, 'b d_in l -> b l d_in')

        x = F.silu(x)

        y = self.ssm(x)

        y = y * F.silu(res)

        output = self.out_proj(y)

        return output


    def ssm(self, x):
        """Runs the SSM. See:
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        Args:
            x: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311

        """
        (d_in, n) = self.A_log.shape

        # Compute ∆ A B C D, the state space parameters.
        #     A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
        #     ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
        #                                  and is why Mamba is called **selective** state spaces)

        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
        D = self.D.float()

        x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)

        (delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1)  # delta: (b, l, dt_rank). B, C: (b, l, n)
        delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)

        y = self.selective_scan(x, delta, A, B, C, D)  # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]

        return y


    def selective_scan(self, x, delta, A, B, C, D):
        """Runs the selective scan algorithm. See:
            - Section 2 State Space Models in the Mamba paper [1]
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        This is the classic discrete state space formula:
            x(t + 1) = Ax(t) + Bu(t)
            y(t)     = Cx(t) + Du(t)
        except B and C (and the step size delta, which is used for discretization) are dependent on the input x(t).

        Args:
            x: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
            delta: shape (b, l, d_in)
            A: shape (d_in, n)
            B: shape (b, l, n)
            C: shape (b, l, n)
            D: shape (d_in,)

        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
            Note: some parts of `selective_scan_ref` were refactored out, so the functionality doesn't match exactly.

        """
        # Sequential scan (kept for reference):
        # (b, l, d_in) = x.shape
        # n = A.shape[1]

        # # Discretize continuous parameters (A, B)
        # # - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1])
        # # - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
        # #   "A is the more important term and the performance doesn't change much with the simplification on B"
        # deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
        # deltaB_x = einsum(delta, B, x, 'b l d_in, b l n, b l d_in -> b l d_in n')

        # # Perform selective scan (see scan_SSM() in The Annotated S4 [2])
        # # Note that the below is sequential, while the official implementation does a much faster parallel scan that
        # # is additionally hardware-aware (like FlashAttention).
        # h = torch.zeros((b, d_in, n), device=deltaA.device)
        # ys = []
        # for i in range(l):
        #     h = deltaA[:, i] * h + deltaB_x[:, i]
        #     y = einsum(h, C[:, i, :], 'b d_in n, b n -> b d_in')
        #     ys.append(y)
        # y = torch.stack(ys, dim=1)  # shape (b, l, d_in)

        # y = y + x * D

        # return y

        # Parallel scan (see the cross-check sketch after this class)
        deltaA = torch.exp(delta.unsqueeze(-1) * A)    # (b, l, d_in, n)
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)  # (b, l, d_in, n)

        BX = deltaB * (x.unsqueeze(-1))                # (b, l, d_in, n)

        hs = pscan(deltaA, BX)                         # hidden states, (b, l, d_in, n)

        y = (hs @ C.unsqueeze(-1)).squeeze(3)  # (b, l, d_in, n) @ (b, l, n, 1) -> (b, l, d_in, 1) -> (b, l, d_in)

        y = y + D * x

        return y

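
# Cross-check sketch (not part of the original model.py): the parallel scan used in
# selective_scan computes the same recurrence as the commented-out sequential loop,
#     h_t = deltaA_t * h_{t-1} + BX_t.
# This assumes the local `parallel_scan.pscan` takes and returns (b, l, d_in, n)
# tensors, as its use above implies. The helper name `_check_pscan_against_loop`
# and the tensor sizes are ours, chosen only for illustration.
def _check_pscan_against_loop(b=2, l=8, d_in=4, n=3):
    deltaA = torch.rand(b, l, d_in, n)     # decay factors in [0, 1)
    BX = torch.randn(b, l, d_in, n)        # per-step inputs
    h = torch.zeros(b, d_in, n)
    hs_loop = []
    for t in range(l):
        h = deltaA[:, t] * h + BX[:, t]
        hs_loop.append(h)
    hs_loop = torch.stack(hs_loop, dim=1)  # (b, l, d_in, n)
    return torch.allclose(pscan(deltaA, BX), hs_loop, atol=1e-5)
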

class RMSNorm(nn.Module):
    def __init__(self,
                 d_model: int,
                 eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))


    def forward(self, x):
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

        return output
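

# Optional smoke test for the illustrative helpers above (ours, not part of the
# original model.py); requires the local `parallel_scan` module to be importable.
if __name__ == "__main__":
    _example_model_args()
    print("logits shape:", tuple(_example_forward().shape))
    print("pscan matches sequential loop:", _check_pscan_against_loop())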