pt-sk committed
Commit 2977f4d
1 Parent(s): 9add2bc

Delete parallel_scan.py

Files changed (1)
  1. parallel_scan.py +0 -226
parallel_scan.py DELETED
@@ -1,226 +0,0 @@
- import math
-
- import torch
- import torch.nn.functional as F
-
- """
-
- An implementation of the parallel scan operation in PyTorch (Blelloch version).
- Please see docs/pscan.ipynb for a detailed explanation of what happens here.
-
- """
-
- def npo2(len):
-     """
-     Returns the smallest power of 2 greater than or equal to len
-     """
-
-     return 2 ** math.ceil(math.log2(len))
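-     # examples (illustrative): npo2(5) == 8, npo2(8) == 8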
-
- def pad_npo2(X):
-     """
-     Pads input length dim to the next power of 2
-
-     Args:
-         X : (B, L, D, N)
-
-     Returns:
-         Y : (B, npo2(L), D, N)
-     """
-
-     len_npo2 = npo2(X.size(1))
-     pad_tuple = (0, 0, 0, 0, 0, len_npo2 - X.size(1))
-     return F.pad(X, pad_tuple, "constant", 0)
-
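- # For example (illustrative shapes): with X of shape (B, L, D, N) = (2, 6, 3, 4),
- # npo2(6) == 8 and pad_npo2(X) returns shape (2, 8, 3, 4). The padded steps are
- # zero-filled, which leaves the first L scan outputs unchanged (the scan is causal).
-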
- class PScan(torch.autograd.Function):
-     @staticmethod
-     def pscan(A, X):
-         # A : (B, D, L, N)
-         # X : (B, D, L, N)
-
-         # modifies X in place by doing a parallel scan.
-         # more formally, X will be populated by these values:
-         # H[t] = A[t] * H[t-1] + X[t] with H[0] = X[0]
-         # which are computed in parallel (2*log2(L) sequential steps (ideally), instead of L sequential steps)
-
-         # only supports L that is a power of two (mainly for clearer code)
-
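-         # unrolled, the first outputs are:
-         #   H[0] = X[0]
-         #   H[1] = A[1]*X[0] + X[1]
-         #   H[2] = A[2]*A[1]*X[0] + A[2]*X[1] + X[2]
-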
-         B, D, L, _ = A.size()
-         num_steps = int(math.log2(L))
-
-         # up sweep (last 2 steps unfolded)
-         Aa = A
-         Xa = X
-         for _ in range(num_steps-2):
-             T = Xa.size(2)
-             Aa = Aa.view(B, D, T//2, 2, -1)
-             Xa = Xa.view(B, D, T//2, 2, -1)
-
-             Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0]))
-             Aa[:, :, :, 1].mul_(Aa[:, :, :, 0])
-
-             Aa = Aa[:, :, :, 1]
-             Xa = Xa[:, :, :, 1]
-
-         # we have only 4, 2 or 1 nodes left
-         if Xa.size(2) == 4:
-             Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
-             Aa[:, :, 1].mul_(Aa[:, :, 0])
-
-             Xa[:, :, 3].add_(Aa[:, :, 3].mul(Xa[:, :, 2] + Aa[:, :, 2].mul(Xa[:, :, 1])))
-         elif Xa.size(2) == 2:
-             Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
-             return
-         else:
-             return
-
-         # down sweep (first 2 steps unfolded)
-         Aa = A[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
-         Xa = X[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
-         Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 1]))
-         Aa[:, :, 2].mul_(Aa[:, :, 1])
-
-         for k in range(num_steps-3, -1, -1):
-             Aa = A[:, :, 2**k-1:L:2**k]
-             Xa = X[:, :, 2**k-1:L:2**k]
-
-             T = Xa.size(2)
-             Aa = Aa.view(B, D, T//2, 2, -1)
-             Xa = Xa.view(B, D, T//2, 2, -1)
-
-             Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0].mul(Xa[:, :, :-1, 1]))
-             Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1])
-
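-     # For comparison, a naive sequential version of the same recurrence would be
-     # (an illustrative sketch, not an optimized path; it matches pscan's in-place
-     # semantics):
-     #
-     #   @staticmethod
-     #   def sequential_scan(A, X):  # A, X : (B, D, L, N), X modified in place
-     #       H = torch.zeros_like(X[:, :, 0])
-     #       for t in range(X.size(2)):
-     #           H = A[:, :, t] * H + X[:, :, t]
-     #           X[:, :, t] = H
-     #
-     # running it instead of pscan(A, X) yields the same X, in O(L) steps instead
-     # of O(log L).
-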
-     @staticmethod
-     def pscan_rev(A, X):
-         # A : (B, D, L, N)
-         # X : (B, D, L, N)
-
-         # the same function as above, but in reverse
-         # (if you flip the input, call pscan, then flip the output, you get what this function outputs)
-         # it is used in the backward pass
-
-         # only supports L that is a power of two (mainly for clearer code)
-
-         B, D, L, _ = A.size()
-         num_steps = int(math.log2(L))
-
-         # up sweep (last 2 steps unfolded)
-         Aa = A
-         Xa = X
-         for _ in range(num_steps-2):
-             T = Xa.size(2)
-             Aa = Aa.view(B, D, T//2, 2, -1)
-             Xa = Xa.view(B, D, T//2, 2, -1)
-
-             Xa[:, :, :, 0].add_(Aa[:, :, :, 0].mul(Xa[:, :, :, 1]))
-             Aa[:, :, :, 0].mul_(Aa[:, :, :, 1])
-
-             Aa = Aa[:, :, :, 0]
-             Xa = Xa[:, :, :, 0]
-
-         # we have only 4, 2 or 1 nodes left
-         if Xa.size(2) == 4:
-             Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 3]))
-             Aa[:, :, 2].mul_(Aa[:, :, 3])
-
-             Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1].add(Aa[:, :, 1].mul(Xa[:, :, 2]))))
-         elif Xa.size(2) == 2:
-             Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1]))
-             return
-         else:
-             return
-
-         # down sweep (first 2 steps unfolded)
-         Aa = A[:, :, 0:L:2**(num_steps-2)]
-         Xa = X[:, :, 0:L:2**(num_steps-2)]
-         Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 2]))
-         Aa[:, :, 1].mul_(Aa[:, :, 2])
-
-         for k in range(num_steps-3, -1, -1):
-             Aa = A[:, :, 0:L:2**k]
-             Xa = X[:, :, 0:L:2**k]
-
-             T = Xa.size(2)
-             Aa = Aa.view(B, D, T//2, 2, -1)
-             Xa = Xa.view(B, D, T//2, 2, -1)
-
-             Xa[:, :, :-1, 1].add_(Aa[:, :, :-1, 1].mul(Xa[:, :, 1:, 0]))
-             Aa[:, :, :-1, 1].mul_(Aa[:, :, 1:, 0])
-
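-     # equivalently (restating the comment above as code; flips are along the time dim):
-     #   A2, X2 = A.flip(2), X.flip(2)
-     #   PScan.pscan(A2, X2)   # in-place on X2
-     #   X2.flip(2)            # == the X that pscan_rev(A, X) produces
-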
-     @staticmethod
-     def forward(ctx, A_in, X_in):
-         """
-         Applies the parallel scan operation, as defined above. Returns a new tensor.
-         Prefer sequence lengths that are powers of two when possible.
-
-         Args:
-             A_in : (B, L, D, N)
-             X_in : (B, L, D, N)
-
-         Returns:
-             H : (B, L, D, N)
-         """
-
-         L = X_in.size(1)
-
-         # cloning is required because of the in-place ops
-         if L == npo2(L):
-             A = A_in.clone()
-             X = X_in.clone()
-         else:
-             # pad tensors (which also clones them)
-             A = pad_npo2(A_in) # (B, npo2(L), D, N)
-             X = pad_npo2(X_in) # (B, npo2(L), D, N)
-
-         # prepare tensors
-         A = A.transpose(2, 1) # (B, D, npo2(L), N)
-         X = X.transpose(2, 1) # (B, D, npo2(L), N)
-
-         # parallel scan (modifies X in-place)
-         PScan.pscan(A, X)
-
-         ctx.save_for_backward(A_in, X)
-
-         # slice [:, :L] (cut if there was padding)
-         return X.transpose(2, 1)[:, :L]
-
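-     # note: the public interface uses the (B, L, D, N) layout, while the scan
-     # kernels above work on (B, D, L, N); forward/backward transpose at the
-     # boundary so callers never deal with the internal layout.
-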
-     @staticmethod
-     def backward(ctx, grad_output_in):
-         """
-         Propagates the gradient from the output back to the inputs. Returns two new tensors.
-
-         Args:
-             ctx : holds the saved tensors A_in : (B, L, D, N) and X : (B, D, L, N)
-             grad_output_in : (B, L, D, N)
-
-         Returns:
-             gradA : (B, L, D, N), gradX : (B, L, D, N)
-         """
-
-         A_in, X = ctx.saved_tensors
-
-         L = grad_output_in.size(1)
-
-         # cloning is required because of the in-place ops
-         if L == npo2(L):
-             grad_output = grad_output_in.clone()
-             # the next padding will clone A_in
-         else:
-             grad_output = pad_npo2(grad_output_in) # (B, npo2(L), D, N)
-             A_in = pad_npo2(A_in) # (B, npo2(L), D, N)
-
-         # prepare tensors
-         grad_output = grad_output.transpose(2, 1)
-         A_in = A_in.transpose(2, 1) # (B, D, npo2(L), N)
-         A = torch.nn.functional.pad(A_in[:, :, 1:], (0, 0, 0, 1)) # (B, D, npo2(L), N) shift 1 to the left (see hand derivation)
-
-         # reverse parallel scan (modifies grad_output in-place)
-         PScan.pscan_rev(A, grad_output)
-
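-         # at this point grad_output[t] holds dL/dH[t], accumulated backward in time:
-         #   dL/dH[t] = grad_out[t] + A[t+1] * dL/dH[t+1]
-         # since H[t] = A[t]*H[t-1] + X[t], we get dL/dX[t] = dL/dH[t] and
-         # dL/dA[t] = H[t-1] * dL/dH[t]; Q below computes the latter from the saved X.
-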
-         Q = torch.zeros_like(X)
-         Q[:, :, 1:].add_(X[:, :, :-1] * grad_output[:, :, 1:])
-
-         return Q.transpose(2, 1)[:, :L], grad_output.transpose(2, 1)[:, :L]
-
- pscan = PScan.apply
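
For reference, a minimal sketch of how the deleted module's pscan operator can be exercised and checked against a naive sequential loop (sizes here are arbitrary; the recurrence follows the docstrings above):

import torch
from parallel_scan import pscan

B, L, D, N = 2, 7, 3, 4   # L deliberately not a power of two
A = torch.randn(B, L, D, N, dtype=torch.float64, requires_grad=True)
X = torch.randn(B, L, D, N, dtype=torch.float64, requires_grad=True)

H = pscan(A, X)           # (B, L, D, N)

# reference: H[t] = A[t] * H[t-1] + X[t] with H[0] = X[0]
h = torch.zeros(B, D, N, dtype=torch.float64)
ref = []
for t in range(L):
    h = A[:, t] * h + X[:, t]
    ref.append(h)
assert torch.allclose(H, torch.stack(ref, dim=1))

H.sum().backward()        # gradients flow to both A and X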