pt-sk commited on
Commit
315f1bc
1 Parent(s): d0824ce

Upload 2 files

Browse files
Files changed (2) hide show
  1. code/model.py +375 -0
  2. code/parallel_scan.py +226 -0
code/model.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Simple, minimal implementation of Mamba in one file of PyTorch.
2
+
3
+ Suggest reading the following before/while reading the code:
4
+ [1] Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Albert Gu and Tri Dao)
5
+ https://arxiv.org/abs/2312.00752
6
+ [2] The Annotated S4 (Sasha Rush and Sidd Karamcheti)
7
+ https://srush.github.io/annotated-s4
8
+
9
+ Glossary:
10
+ b: batch size (`B` in Mamba paper [1] Algorithm 2)
11
+ l: sequence length (`L` in [1] Algorithm 2)
12
+ d or d_model: hidden dim
13
+ n or d_state: latent state dim (`N` in [1] Algorithm 2)
14
+ expand: expansion factor (`E` in [1] Section 3.4)
15
+ d_in or d_inner: d * expand (`D` in [1] Algorithm 2)
16
+ A, B, C, D: state space parameters (See any state space representation formula)
17
+ (B, C are input-dependent (aka selective, a key innovation in Mamba); A, D are not)
18
+ Δ or delta: input-dependent step size
19
+ dt_rank: rank of Δ (See [1] Section 3.6 "Parameterization of ∆")
20
+
21
+ """
22
+ from __future__ import annotations
23
+ import math
24
+ import json
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ from dataclasses import dataclass
29
+ from typing import Union
30
+ from einops import rearrange, repeat, einsum
31
+ from parallel_scan import pscan
32
+
33
+
34
+ @dataclass
35
+ class ModelArgs:
36
+ d_model: int
37
+ n_layer: int
38
+ vocab_size: int
39
+ d_state: int = 16
40
+ expand: int = 2
41
+ dt_rank: Union[int, str] = 'auto'
42
+ d_conv: int = 4
43
+ pad_vocab_size_multiple: int = 8
44
+ conv_bias: bool = True
45
+ bias: bool = False
46
+
47
+ def __post_init__(self):
48
+ self.d_inner = int(self.expand * self.d_model)
49
+
50
+ if self.dt_rank == 'auto':
51
+ self.dt_rank = math.ceil(self.d_model / 16)
52
+
53
+ if self.vocab_size % self.pad_vocab_size_multiple != 0:
54
+ self.vocab_size += (self.pad_vocab_size_multiple
55
+ - self.vocab_size % self.pad_vocab_size_multiple)
56
+
57
+
58
+ class Mamba(nn.Module):
59
+ def __init__(self, args: ModelArgs):
60
+ """Full Mamba model."""
61
+ super().__init__()
62
+ self.args = args
63
+
64
+ self.embedding = nn.Embedding(args.vocab_size, args.d_model)
65
+ self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])
66
+ self.norm_f = RMSNorm(args.d_model)
67
+
68
+ self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
69
+ self.lm_head.weight = self.embedding.weight # Tie output projection to embedding weights.
70
+ # See "Weight Tying" paper
71
+
72
+
73
+ def forward(self, input_ids):
74
+ """
75
+ Args:
76
+ input_ids (long tensor): shape (b, l) (See Glossary at top for definitions of b, l, d_in, n...)
77
+
78
+ Returns:
79
+ logits: shape (b, l, vocab_size)
80
+
81
+ Official Implementation:
82
+ class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L173
83
+
84
+ """
85
+ x = self.embedding(input_ids)
86
+
87
+ for layer in self.layers:
88
+ x = layer(x)
89
+
90
+ x = self.norm_f(x)
91
+ logits = self.lm_head(x)
92
+
93
+ return logits
94
+
95
+
96
+ @staticmethod
97
+ def from_config(pretrained_model_name: str):
98
+ from transformers.utils import CONFIG_NAME
99
+ from transformers.utils.hub import cached_file
100
+
101
+ def load_config_hf(model_name):
102
+ resolved_archive_file = cached_file(model_name, CONFIG_NAME,
103
+ _raise_exceptions_for_missing_entries=False)
104
+ return json.load(open(resolved_archive_file))
105
+ config_data = load_config_hf(pretrained_model_name)
106
+ args = ModelArgs(
107
+ d_model=config_data['d_model'],
108
+ n_layer=config_data['n_layer'],
109
+ vocab_size=config_data['vocab_size']
110
+ )
111
+ model = Mamba(args)
112
+ return model
113
+
114
+
115
+ @staticmethod
116
+ def from_pretrained(pretrained_model_name: str):
117
+ """Load pretrained weights from HuggingFace into model.
118
+
119
+ Args:
120
+ pretrained_model_name: One of
121
+ * 'state-spaces/mamba-2.8b-slimpj'
122
+ * 'state-spaces/mamba-2.8b'
123
+ * 'state-spaces/mamba-1.4b'
124
+ * 'state-spaces/mamba-790m'
125
+ * 'state-spaces/mamba-370m'
126
+ * 'state-spaces/mamba-130m'
127
+
128
+ Returns:
129
+ model: Mamba model with weights loaded
130
+
131
+ """
132
+ from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
133
+ from transformers.utils.hub import cached_file
134
+
135
+ def load_config_hf(model_name):
136
+ resolved_archive_file = cached_file(model_name, CONFIG_NAME,
137
+ _raise_exceptions_for_missing_entries=False)
138
+ return json.load(open(resolved_archive_file))
139
+
140
+
141
+ def load_state_dict_hf(model_name, device=None, dtype=None):
142
+ resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
143
+ _raise_exceptions_for_missing_entries=False)
144
+ return torch.load(resolved_archive_file, weights_only=True, map_location='cpu', mmap=True)
145
+
146
+ config_data = load_config_hf(pretrained_model_name)
147
+ args = ModelArgs(
148
+ d_model=config_data['d_model'],
149
+ n_layer=config_data['n_layer'],
150
+ vocab_size=config_data['vocab_size']
151
+ )
152
+ model = Mamba(args)
153
+
154
+ state_dict = load_state_dict_hf(pretrained_model_name)
155
+ new_state_dict = {}
156
+ for key in state_dict:
157
+ new_key = key.replace('backbone.', '')
158
+ new_state_dict[new_key] = state_dict[key]
159
+ model.load_state_dict(new_state_dict)
160
+
161
+ return model
162
+
163
+
164
+ class ResidualBlock(nn.Module):
165
+ def __init__(self, args: ModelArgs):
166
+ """Simple block wrapping Mamba block with normalization and residual connection."""
167
+ super().__init__()
168
+ self.args = args
169
+ self.mixer = MambaBlock(args)
170
+ self.norm = RMSNorm(args.d_model)
171
+
172
+
173
+ def forward(self, x):
174
+ """
175
+ Args:
176
+ x: shape (b, l, d) (See Glossary at top for definitions of b, l, d_in, n...)
177
+
178
+ Returns:
179
+ output: shape (b, l, d)
180
+
181
+ Official Implementation:
182
+ Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297
183
+
184
+ Note: the official repo chains residual blocks that look like
185
+ [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
186
+ where the first Add is a no-op. This is purely for performance reasons as this
187
+ allows them to fuse the Add->Norm.
188
+
189
+ We instead implement our blocks as the more familiar, simpler, and numerically equivalent
190
+ [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ....
191
+
192
+ """
193
+ output = self.mixer(self.norm(x)) + x
194
+
195
+ return output
196
+
197
+
198
+ class MambaBlock(nn.Module):
199
+ def __init__(self, args: ModelArgs):
200
+ """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1]."""
201
+ super().__init__()
202
+ self.args = args
203
+
204
+ self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)
205
+
206
+ self.conv1d = nn.Conv1d(
207
+ in_channels=args.d_inner,
208
+ out_channels=args.d_inner,
209
+ bias=args.conv_bias,
210
+ kernel_size=args.d_conv,
211
+ groups=args.d_inner,
212
+ padding=args.d_conv - 1,
213
+ )
214
+
215
+ # x_proj takes in `x` and outputs the input-specific Δ, B, C
216
+ self.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False)
217
+
218
+ # dt_proj projects Δ from dt_rank to d_in
219
+ self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True)
220
+
221
+ A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)
222
+ self.A_log = nn.Parameter(torch.log(A))
223
+ self.D = nn.Parameter(torch.ones(args.d_inner))
224
+ self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)
225
+
226
+
227
+ def forward(self, x):
228
+ """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1].
229
+
230
+ Args:
231
+ x: shape (b, l, d) (See Glossary at top for definitions of b, l, d_in, n...)
232
+
233
+ Returns:
234
+ output: shape (b, l, d)
235
+
236
+ Official Implementation:
237
+ class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
238
+ mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
239
+
240
+ """
241
+ (b, l, d) = x.shape
242
+
243
+ x_and_res = self.in_proj(x) # shape (b, l, 2 * d_in)
244
+ (x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)
245
+
246
+ x = rearrange(x, 'b l d_in -> b d_in l')
247
+ x = self.conv1d(x)[:, :, :l]
248
+ x = rearrange(x, 'b d_in l -> b l d_in')
249
+
250
+ x = F.silu(x)
251
+
252
+ y = self.ssm(x)
253
+
254
+ y = y * F.silu(res)
255
+
256
+ output = self.out_proj(y)
257
+
258
+ return output
259
+
260
+
261
+ def ssm(self, x):
262
+ """Runs the SSM. See:
263
+ - Algorithm 2 in Section 3.2 in the Mamba paper [1]
264
+ - run_SSM(A, B, C, u) in The Annotated S4 [2]
265
+
266
+ Args:
267
+ x: shape (b, l, d_in) (See Glossary at top for definitions of b, l, d_in, n...)
268
+
269
+ Returns:
270
+ output: shape (b, l, d_in)
271
+
272
+ Official Implementation:
273
+ mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
274
+
275
+ """
276
+ (d_in, n) = self.A_log.shape
277
+
278
+ # Compute ∆ A B C D, the state space parameters.
279
+ # A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
280
+ # ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
281
+ # and is why Mamba is called **selective** state spaces)
282
+
283
+ A = -torch.exp(self.A_log.float()) # shape (d_in, n)
284
+ D = self.D.float()
285
+
286
+ x_dbl = self.x_proj(x) # (b, l, dt_rank + 2*n)
287
+
288
+ (delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1) # delta: (b, l, dt_rank). B, C: (b, l, n)
289
+ delta = F.softplus(self.dt_proj(delta)) # (b, l, d_in)
290
+
291
+ y = self.selective_scan(x, delta, A, B, C, D) # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]
292
+
293
+ return y
294
+
295
+
296
+ def selective_scan(self, x, delta, A, B, C, D):
297
+ """Does selective scan algorithm. See:
298
+ - Section 2 State Space Models in the Mamba paper [1]
299
+ - Algorithm 2 in Section 3.2 in the Mamba paper [1]
300
+ - run_SSM(A, B, C, u) in The Annotated S4 [2]
301
+
302
+ This is the classic discrete state space formula:
303
+ x(t + 1) = Ax(t) + Bu(t)
304
+ y(t) = Cx(t) + Du(t)
305
+ except B and C (and the step size delta, which is used for discretization) are dependent on the input x(t).
306
+
307
+ Args:
308
+ u: shape (b, l, d_in) (See Glossary at top for definitions of b, l, d_in, n...)
309
+ delta: shape (b, l, d_in)
310
+ A: shape (d_in, n)
311
+ B: shape (b, l, n)
312
+ C: shape (b, l, n)
313
+ D: shape (d_in,)
314
+
315
+ Returns:
316
+ output: shape (b, l, d_in)
317
+
318
+ Official Implementation:
319
+ selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
320
+ Note: I refactored some parts out of `selective_scan_ref` out, so the functionality doesn't match exactly.
321
+
322
+ """
323
+ # sequential scan
324
+ # (b, l, d_in) = u.shape
325
+ # n = A.shape[1]
326
+
327
+ # # Discretize continuous parameters (A, B)
328
+ # # - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1])
329
+ # # - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
330
+ # # "A is the more important term and the performance doesn't change much with the simplification on B"
331
+ # deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
332
+ # deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')
333
+
334
+ # # Perform selective scan (see scan_SSM() in The Annotated S4 [2])
335
+ # # Note that the below is sequential, while the official implementation does a much faster parallel scan that
336
+ # # is additionally hardware-aware (like FlashAttention).
337
+ # x = torch.zeros((b, d_in, n), device=deltaA.device)
338
+ # ys = []
339
+ # for i in range(l):
340
+ # x = deltaA[:, i] * x + deltaB_u[:, i]
341
+ # y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
342
+ # ys.append(y)
343
+ # y = torch.stack(ys, dim=1) # shape (b, l, d_in)
344
+
345
+ # y = y + u * D
346
+
347
+ # return y
348
+ # parallel scan
349
+ deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, ED, N)
350
+ deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)
351
+
352
+ BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)
353
+
354
+ hs = pscan(deltaA, BX)
355
+
356
+ y = (hs @ C.unsqueeze(-1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)
357
+
358
+ y = y + D * x
359
+
360
+ return y
361
+
362
+
363
+ class RMSNorm(nn.Module):
364
+ def __init__(self,
365
+ d_model: int,
366
+ eps: float = 1e-5):
367
+ super().__init__()
368
+ self.eps = eps
369
+ self.weight = nn.Parameter(torch.ones(d_model))
370
+
371
+
372
+ def forward(self, x):
373
+ output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
374
+
375
+ return output
code/parallel_scan.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ """
7
+
8
+ An implementation of the parallel scan operation in PyTorch (Blelloch version).
9
+ Please see docs/pscan.ipynb for a detailed explanation of what happens here.
10
+
11
+ """
12
+
13
+ def npo2(len):
14
+ """
15
+ Returns the next power of 2 above len
16
+ """
17
+
18
+ return 2 ** math.ceil(math.log2(len))
19
+
20
+ def pad_npo2(X):
21
+ """
22
+ Pads input length dim to the next power of 2
23
+
24
+ Args:
25
+ X : (B, L, D, N)
26
+
27
+ Returns:
28
+ Y : (B, npo2(L), D, N)
29
+ """
30
+
31
+ len_npo2 = npo2(X.size(1))
32
+ pad_tuple = (0, 0, 0, 0, 0, len_npo2 - X.size(1))
33
+ return F.pad(X, pad_tuple, "constant", 0)
34
+
35
+ class PScan(torch.autograd.Function):
36
+ @staticmethod
37
+ def pscan(A, X):
38
+ # A : (B, D, L, N)
39
+ # X : (B, D, L, N)
40
+
41
+ # modifies X in place by doing a parallel scan.
42
+ # more formally, X will be populated by these values :
43
+ # H[t] = A[t] * H[t-1] + X[t] with H[0] = 0
44
+ # which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps)
45
+
46
+ # only supports L that is a power of two (mainly for a clearer code)
47
+
48
+ B, D, L, _ = A.size()
49
+ num_steps = int(math.log2(L))
50
+
51
+ # up sweep (last 2 steps unfolded)
52
+ Aa = A
53
+ Xa = X
54
+ for _ in range(num_steps-2):
55
+ T = Xa.size(2)
56
+ Aa = Aa.view(B, D, T//2, 2, -1)
57
+ Xa = Xa.view(B, D, T//2, 2, -1)
58
+
59
+ Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0]))
60
+ Aa[:, :, :, 1].mul_(Aa[:, :, :, 0])
61
+
62
+ Aa = Aa[:, :, :, 1]
63
+ Xa = Xa[:, :, :, 1]
64
+
65
+ # we have only 4, 2 or 1 nodes left
66
+ if Xa.size(2) == 4:
67
+ Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
68
+ Aa[:, :, 1].mul_(Aa[:, :, 0])
69
+
70
+ Xa[:, :, 3].add_(Aa[:, :, 3].mul(Xa[:, :, 2] + Aa[:, :, 2].mul(Xa[:, :, 1])))
71
+ elif Xa.size(2) == 2:
72
+ Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
73
+ return
74
+ else:
75
+ return
76
+
77
+ # down sweep (first 2 steps unfolded)
78
+ Aa = A[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
79
+ Xa = X[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
80
+ Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 1]))
81
+ Aa[:, :, 2].mul_(Aa[:, :, 1])
82
+
83
+ for k in range(num_steps-3, -1, -1):
84
+ Aa = A[:, :, 2**k-1:L:2**k]
85
+ Xa = X[:, :, 2**k-1:L:2**k]
86
+
87
+ T = Xa.size(2)
88
+ Aa = Aa.view(B, D, T//2, 2, -1)
89
+ Xa = Xa.view(B, D, T//2, 2, -1)
90
+
91
+ Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0].mul(Xa[:, :, :-1, 1]))
92
+ Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1])
93
+
94
+ @staticmethod
95
+ def pscan_rev(A, X):
96
+ # A : (B, D, L, N)
97
+ # X : (B, D, L, N)
98
+
99
+ # the same function as above, but in reverse
100
+ # (if you flip the input, call pscan, then flip the output, you get what this function outputs)
101
+ # it is used in the backward pass
102
+
103
+ # only supports L that is a power of two (mainly for a clearer code)
104
+
105
+ B, D, L, _ = A.size()
106
+ num_steps = int(math.log2(L))
107
+
108
+ # up sweep (last 2 steps unfolded)
109
+ Aa = A
110
+ Xa = X
111
+ for _ in range(num_steps-2):
112
+ T = Xa.size(2)
113
+ Aa = Aa.view(B, D, T//2, 2, -1)
114
+ Xa = Xa.view(B, D, T//2, 2, -1)
115
+
116
+ Xa[:, :, :, 0].add_(Aa[:, :, :, 0].mul(Xa[:, :, :, 1]))
117
+ Aa[:, :, :, 0].mul_(Aa[:, :, :, 1])
118
+
119
+ Aa = Aa[:, :, :, 0]
120
+ Xa = Xa[:, :, :, 0]
121
+
122
+ # we have only 4, 2 or 1 nodes left
123
+ if Xa.size(2) == 4:
124
+ Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 3]))
125
+ Aa[:, :, 2].mul_(Aa[:, :, 3])
126
+
127
+ Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1].add(Aa[:, :, 1].mul(Xa[:, :, 2]))))
128
+ elif Xa.size(2) == 2:
129
+ Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1]))
130
+ return
131
+ else:
132
+ return
133
+
134
+ # down sweep (first 2 steps unfolded)
135
+ Aa = A[:, :, 0:L:2**(num_steps-2)]
136
+ Xa = X[:, :, 0:L:2**(num_steps-2)]
137
+ Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 2]))
138
+ Aa[:, :, 1].mul_(Aa[:, :, 2])
139
+
140
+ for k in range(num_steps-3, -1, -1):
141
+ Aa = A[:, :, 0:L:2**k]
142
+ Xa = X[:, :, 0:L:2**k]
143
+
144
+ T = Xa.size(2)
145
+ Aa = Aa.view(B, D, T//2, 2, -1)
146
+ Xa = Xa.view(B, D, T//2, 2, -1)
147
+
148
+ Xa[:, :, :-1, 1].add_(Aa[:, :, :-1, 1].mul(Xa[:, :, 1:, 0]))
149
+ Aa[:, :, :-1, 1].mul_(Aa[:, :, 1:, 0])
150
+
151
+ @staticmethod
152
+ def forward(ctx, A_in, X_in):
153
+ """
154
+ Applies the parallel scan operation, as defined above. Returns a new tensor.
155
+ If you can, privilege sequence lengths that are powers of two.
156
+
157
+ Args:
158
+ A_in : (B, L, D, N)
159
+ X_in : (B, L, D, N)
160
+
161
+ Returns:
162
+ H : (B, L, D, N)
163
+ """
164
+
165
+ L = X_in.size(1)
166
+
167
+ # cloning is requiered because of the in-place ops
168
+ if L == npo2(L):
169
+ A = A_in.clone()
170
+ X = X_in.clone()
171
+ else:
172
+ # pad tensors (and clone btw)
173
+ A = pad_npo2(A_in) # (B, npo2(L), D, N)
174
+ X = pad_npo2(X_in) # (B, npo2(L), D, N)
175
+
176
+ # prepare tensors
177
+ A = A.transpose(2, 1) # (B, D, npo2(L), N)
178
+ X = X.transpose(2, 1) # (B, D, npo2(L), N)
179
+
180
+ # parallel scan (modifies X in-place)
181
+ PScan.pscan(A, X)
182
+
183
+ ctx.save_for_backward(A_in, X)
184
+
185
+ # slice [:, :L] (cut if there was padding)
186
+ return X.transpose(2, 1)[:, :L]
187
+
188
+ @staticmethod
189
+ def backward(ctx, grad_output_in):
190
+ """
191
+ Flows the gradient from the output to the input. Returns two new tensors.
192
+
193
+ Args:
194
+ ctx : A_in : (B, L, D, N), X : (B, D, L, N)
195
+ grad_output_in : (B, L, D, N)
196
+
197
+ Returns:
198
+ gradA : (B, L, D, N), gradX : (B, L, D, N)
199
+ """
200
+
201
+ A_in, X = ctx.saved_tensors
202
+
203
+ L = grad_output_in.size(1)
204
+
205
+ # cloning is requiered because of the in-place ops
206
+ if L == npo2(L):
207
+ grad_output = grad_output_in.clone()
208
+ # the next padding will clone A_in
209
+ else:
210
+ grad_output = pad_npo2(grad_output_in) # (B, npo2(L), D, N)
211
+ A_in = pad_npo2(A_in) # (B, npo2(L), D, N)
212
+
213
+ # prepare tensors
214
+ grad_output = grad_output.transpose(2, 1)
215
+ A_in = A_in.transpose(2, 1) # (B, D, npo2(L), N)
216
+ A = torch.nn.functional.pad(A_in[:, :, 1:], (0, 0, 0, 1)) # (B, D, npo2(L), N) shift 1 to the left (see hand derivation)
217
+
218
+ # reverse parallel scan (modifies grad_output in-place)
219
+ PScan.pscan_rev(A, grad_output)
220
+
221
+ Q = torch.zeros_like(X)
222
+ Q[:, :, 1:].add_(X[:, :, :-1] * grad_output[:, :, 1:])
223
+
224
+ return Q.transpose(2, 1)[:, :L], grad_output.transpose(2, 1)[:, :L]
225
+
226
+ pscan = PScan.apply