SmerkyG commited on
Commit
8a773e0
·
verified ·
1 Parent(s): 0880dfe

Update modeling_rwkv7.py

Browse files
Files changed (1) hide show
  1. modeling_rwkv7.py +874 -874
modeling_rwkv7.py CHANGED
@@ -1,874 +1,874 @@
1
- # coding=utf-8
2
- # Copyright 2024 The RWKV team and HuggingFace Inc. team.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """PyTorch RWKV7 World model."""
16
-
17
- from dataclasses import dataclass
18
- from typing import List, Optional, Tuple, Union
19
-
20
- from pathlib import Path
21
-
22
- import math
23
- import torch
24
- import torch.nn.functional as F
25
- import torch.utils.checkpoint
26
- from torch import nn
27
- from torch.nn import CrossEntropyLoss
28
-
29
- from transformers.modeling_utils import PreTrainedModel, GenerationMixin, _init_weights
30
- from transformers.utils import (
31
- ModelOutput,
32
- add_code_sample_docstrings,
33
- add_start_docstrings,
34
- add_start_docstrings_to_model_forward,
35
- is_ninja_available,
36
- is_torch_cuda_available,
37
- logging,
38
- )
39
-
40
- from .configuration_rwkv7 import Rwkv7Config
41
-
42
- # MIT License
43
-
44
- # Copyright (c) 2024 Songlin Yang
45
-
46
- # Permission is hereby granted, free of charge, to any person obtaining a copy
47
- # of this software and associated documentation files (the "Software"), to deal
48
- # in the Software without restriction, including without limitation the rights
49
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
50
- # copies of the Software, and to permit persons to whom the Software is
51
- # furnished to do so, subject to the following conditions:
52
-
53
- # The above copyright notice and this permission notice shall be included in all
54
- # copies or substantial portions of the Software.
55
-
56
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
57
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
58
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
59
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
60
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
61
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
62
- # SOFTWARE.
63
-
64
- # Copyright (c) 2024, Johan Sokrates Wind
65
-
66
- import torch as th
67
- import triton
68
- import triton.language as tl
69
-
70
- @triton.jit
71
- def IND4(a,b,c,d,nb,nc,nd):
72
- return ((a*nb+b)*nc+c)*nd+d
73
- @triton.jit
74
- def IND5(a,b,c,d,e,nb,nc,nd,ne):
75
- return (((a*nb+b)*nc+c)*nd+d)*ne+e
76
-
77
- @triton.jit
78
- def _prod(a,b): return a*b
79
-
80
- # inv(I-A) where A is a strictly lower triangular nxn matrix
81
- @triton.jit
82
- def tri_minv(A, n:tl.constexpr, prec:tl.constexpr):
83
- i = tl.arange(0,n)
84
- prod = (i[None,:]==i[:,None]).to(tl.float32)
85
- for j in range(n-1):
86
- prod += tl_dot(prec, prod, (A*((i[None,:]==j)*(i[:,None]>i[None,:]))).trans())
87
- return prod.trans()
88
-
89
- @triton.jit
90
- def fw_attn_triton(w_,q_,k_,v_,a_,b_, s0_,y_,s_,sT_, B:tl.constexpr,T:tl.constexpr,H:tl.constexpr,C:tl.constexpr,dT:tl.constexpr, prec:tl.constexpr):
91
- bi = tl.program_id(1)
92
- hi = tl.program_id(0)
93
-
94
- i = tl.arange(0,C)[None,:]
95
- state = tl.load(s0_+IND4(bi,hi,i.trans(),i, H,C,C)).to(tl.float32)
96
- for t0 in range(T//dT):
97
- t = t0*dT+tl.arange(0,dT)[:,None]
98
- sw = tl.load(w_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
99
- sq = tl.load(q_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
100
- sk = tl.load(k_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
101
- sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
102
- sa = tl.load(a_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
103
- sb = tl.load(b_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
104
-
105
- w = (-sw.exp()).exp()
106
- fw = tl.reduce(w, 0, _prod, keep_dims=True)
107
- incl_pref = tl.cumprod(w,axis=0)
108
- non_incl_pref = incl_pref / w
109
- inv_incl_pref = 1 / incl_pref
110
-
111
- wq = sq * incl_pref
112
- wa = sa * non_incl_pref
113
- kwi = sk * inv_incl_pref
114
- bwi = sb * inv_incl_pref
115
-
116
- mask1 = (t > t.trans())
117
- ab = tl_dot(prec, wa, bwi.trans()) * mask1
118
- ak = tl_dot(prec, wa, kwi.trans()) * mask1
119
-
120
- ab_inv = tri_minv(ab, dT, prec)
121
-
122
- ab_u = tl_dot(prec, ak, sv) + tl_dot(prec, wa, state.trans())
123
- u = tl_dot(prec, ab_inv, ab_u)
124
- mask2 = (t >= t.trans())
125
- qk = tl_dot(prec, wq, kwi.trans()) * mask2
126
- qb = tl_dot(prec, wq, bwi.trans()) * mask2
127
- yy = tl_dot(prec, qk, sv) + tl_dot(prec, qb, u) + tl_dot(prec, wq, state.trans())
128
- tl.store(y_+IND4(bi,t,hi,i, T,H,C), yy.to(tl.bfloat16))
129
-
130
- tl.store(s_+IND5(bi,hi,t0,i.trans(),i, H,T//dT,C,C), state.to(tl.float32))
131
- state = state * fw + tl_dot(prec, sv.trans(), kwi*fw) + tl_dot(prec, u.trans(), bwi*fw)
132
- tl.store(sT_+IND4(bi,hi,i.trans(),i, H,C,C), state.to(tl.bfloat16))
133
-
134
- @triton.jit
135
- def bw_attn_triton(w_,q_,k_,v_,a_,b_, dy_,s_,dsT_, dw_,dq_,dk_,dv_,da_,db_,ds0_, B:tl.constexpr,T:tl.constexpr,H:tl.constexpr,C:tl.constexpr,dT:tl.constexpr, prec:tl.constexpr):
136
- bi = tl.program_id(1)
137
- hi = tl.program_id(0)
138
-
139
- i = tl.arange(0,C)[None,:]
140
- dstate = tl.load(dsT_+IND4(bi,hi,i.trans(),i, H,C,C)).to(tl.float32)
141
-
142
- for t0 in range(T//dT-1,-1,-1):
143
- t = t0*dT+tl.arange(0,dT)[:,None]
144
-
145
- state = tl.load(s_+IND5(bi,hi,t0,i.trans(),i, H,T//dT,C,C)).to(tl.float32)
146
-
147
- sw = tl.load(w_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
148
- sq = tl.load(q_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
149
- sk = tl.load(k_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
150
- sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
151
- sa = tl.load(a_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
152
- sb = tl.load(b_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
153
- sdy = tl.load(dy_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
154
-
155
- dw_fac = -sw.exp()
156
- w = dw_fac.exp()
157
- fw = tl.reduce(w, 0, _prod, keep_dims=True)
158
- incl_pref = tl.cumprod(w,axis=0)
159
- non_incl_pref = incl_pref / w
160
- inv_incl_pref = 1 / incl_pref
161
-
162
- wq = sq * incl_pref
163
- wa = sa * non_incl_pref
164
- kwi = sk * inv_incl_pref
165
- bwi = sb * inv_incl_pref
166
-
167
- mask1 = (t > t.trans())
168
- ab = tl_dot(prec, wa, bwi.trans()) * mask1
169
- ak = tl_dot(prec, wa, kwi.trans()) * mask1
170
-
171
- ab_inv = tri_minv(ab, dT, prec)
172
-
173
- ab_u = tl_dot(prec, ak, sv) + tl_dot(prec, wa, state.trans())
174
- u = tl_dot(prec, ab_inv, ab_u)
175
- mask2 = (t >= t.trans())
176
- qk = tl_dot(prec, wq, kwi.trans()) * mask2
177
- qb = tl_dot(prec, wq, bwi.trans()) * mask2
178
-
179
- du = tl_dot(prec, qb.trans(), sdy) + tl_dot(prec, bwi*fw, dstate.trans())
180
- dab_u = tl_dot(prec, ab_inv.trans(), du)
181
-
182
- dv = tl_dot(prec, qk.trans(), sdy) + tl_dot(prec, kwi*fw, dstate.trans()) + tl_dot(prec, ak.trans(), dab_u)
183
- tl.store(dv_+IND4(bi,t,hi,i, T,H,C), dv.to(tl.bfloat16))
184
-
185
- dab = tl_dot(prec, tl_dot(prec, ab_inv.trans(), du), u.trans()) * mask1
186
- dak = tl_dot(prec, dab_u, sv.trans()) * mask1
187
- dab_u_state = tl_dot(prec, dab_u, state)
188
- da = non_incl_pref * (tl_dot(prec, dab, bwi) + tl_dot(prec, dak, kwi) + dab_u_state)
189
- tl.store(da_+IND4(bi,t,hi,i, T,H,C), da.to(tl.bfloat16))
190
-
191
- dqb = tl_dot(prec, sdy, u.trans()) * mask2
192
- dqk = tl_dot(prec, sdy, sv.trans()) * mask2
193
- dy_state = tl_dot(prec, sdy, state)
194
- dq = incl_pref * (tl_dot(prec, dqb, bwi) + tl_dot(prec, dqk, kwi) + dy_state)
195
- tl.store(dq_+IND4(bi,t,hi,i, T,H,C), dq.to(tl.bfloat16))
196
-
197
- fw_u_dstate = fw * tl_dot(prec, u, dstate)
198
- db = inv_incl_pref * (tl_dot(prec, dab.trans(), wa) + tl_dot(prec, dqb.trans(), wq) + fw_u_dstate)
199
- tl.store(db_+IND4(bi,t,hi,i, T,H,C), db.to(tl.bfloat16))
200
-
201
- fw_v_dstate = fw * tl_dot(prec, sv, dstate)
202
- dk = inv_incl_pref * (tl_dot(prec, dak.trans(), wa) + tl_dot(prec, dqk.trans(), wq) + fw_v_dstate)
203
- tl.store(dk_+IND4(bi,t,hi,i, T,H,C), dk.to(tl.bfloat16))
204
-
205
- dw0 = fw * tl.sum(state*dstate, axis=0,keep_dims=True)
206
- for k in range(t0*dT,t0*dT+dT):
207
- lmask = (t<k).trans()
208
- A = (tl_dot(prec, dab*lmask, bwi) + tl_dot(prec, dak*lmask, kwi)) * wa * (t>k)
209
- A += (tl_dot(prec, dqb*lmask, bwi) + tl_dot(prec, dqk*lmask, kwi)) * wq * (t>=k)
210
- A += (fw_v_dstate*kwi + fw_u_dstate*bwi) * (t<k)
211
- A += dab_u_state*wa * (t>k) + dy_state*wq * (t>=k)
212
- dw = tl.sum(A, axis=0,keep_dims=True) + dw0
213
-
214
- wk = tl.load(w_+IND4(bi,k,hi,i, T,H,C)).to(tl.float32)
215
- dw *= -wk.exp()
216
- tl.store(dw_+IND4(bi,k,hi,i, T,H,C), dw.to(tl.bfloat16))
217
-
218
- dstate = dstate * fw + tl_dot(prec, sdy.trans(), wq) + tl_dot(prec, dab_u.trans(), wa)
219
- tl.store(ds0_+IND4(bi,hi,i.trans(),i, H,C,C), dstate.to(tl.bfloat16))
220
-
221
-
222
- class TritonRWKV7(th.autograd.Function):
223
- @staticmethod
224
- def forward(ctx, w,q,k,v,z,b,s0, dot_prec):
225
- K = 16
226
- B,T,H,C = w.shape
227
- s0 = th.zeros(B,H,C,C, dtype=w.dtype,device=w.device) if s0 is None else s0
228
- y = th.empty_like(v)
229
- sT = th.empty_like(s0)
230
- s = th.zeros(B,H,T//K,C,C, dtype=th.float32,device=w.device)
231
- fw_attn_triton[(H,B)](w,q,k,v,z,b, s0,y,s,sT, B,T,H,C,K, dot_prec)
232
- ctx.dot_prec = dot_prec
233
- ctx.save_for_backward(w,q,k,v,z,b,s)
234
- return y, sT
235
- @staticmethod
236
- def backward(ctx, dy, dsT):
237
- K = 16
238
- w,q,k,v,z,b,s = ctx.saved_tensors
239
- B,T,H,C = w.shape
240
- dw,dq,dk,dv,dz,db,ds0 = [th.empty_like(x) for x in [w,q,k,v,z,b,dsT]]
241
- bw_attn_triton[(H,B)](w,q,k,v,z,b, dy,s,dsT, dw,dq,dk,dv,dz,db,ds0, B,T,H,C,K, ctx.dot_prec)
242
- return dw,dq,dk,dv,dz,db,ds0,None
243
-
244
- @triton.jit
245
- def tl_dot(prec:tl.constexpr, a, b) -> torch.Tensor:
246
- if prec == 'fp32':
247
- return tl.dot(a.to(tl.float32),b.trans().to(tl.float32).trans(), allow_tf32=False)
248
- elif prec == 'tf32':
249
- return tl.dot(a.to(tl.float32),b.trans().to(tl.float32).trans(), allow_tf32=True)
250
- elif prec == 'bf16':
251
- return tl.dot(a.to(tl.bfloat16),b.trans().to(tl.bfloat16).trans(), allow_tf32=True)
252
- else:
253
- tl.static_assert(False)
254
-
255
- def rwkv7_attn_triton(r,w,k,v,a,b, HEAD_SIZE, dot_prec = 'fp32'):
256
- B,T,HC = w.shape
257
- C = HEAD_SIZE
258
- H = HC//C
259
- r,w,k,v,a,b = [i.view(B,T,H,C) for i in [r,w,k,v,a,b]]
260
- s0 = th.zeros(B,H,C,C, dtype=th.bfloat16,device=w.device)
261
- return TritonRWKV7.apply(w,r,k,v,a,b,s0,dot_prec)[0].view(B,T,HC)
262
-
263
- logger = logging.get_logger(__name__)
264
-
265
- _CHECKPOINT_FOR_DOC = "RWKV/v7-Goose-1.6B-Pile-HF"
266
- _CONFIG_FOR_DOC = "Rwkv7Config"
267
-
268
- class Rwkv7SelfAttention(nn.Module):
269
- def __init__(self, config, layer_id=0):
270
- super().__init__()
271
- self.config = config
272
- self.layer_id = layer_id
273
- C = hidden_size = config.hidden_size
274
- attention_hidden_size = config.attention_hidden_size
275
- self.attention_hidden_size = attention_hidden_size
276
- H = self.num_heads = attention_hidden_size // config.head_size
277
- N = self.head_size = config.head_size
278
-
279
- calc_lora_rank = lambda exponent, multiplier: max(1, round(hidden_size ** exponent * multiplier / 32)) * 32
280
- lora_rank_decay = config.lora_rank_decay or calc_lora_rank(0.5, 1.8)
281
- lora_rank_iclr = config.lora_rank_iclr or calc_lora_rank(0.5, 1.8)
282
- lora_rank_value_residual_mix = config.lora_rank_value_residual_mix or calc_lora_rank(0.5, 1.3)
283
- lora_rank_gate = config.lora_rank_gate or calc_lora_rank(0.8, 0.6)
284
-
285
- self.x_r = nn.Parameter(torch.empty(1,1,C))
286
- self.x_w = nn.Parameter(torch.empty(1,1,C))
287
- self.x_k = nn.Parameter(torch.empty(1,1,C))
288
- self.x_v = nn.Parameter(torch.empty(1,1,C))
289
- self.x_a = nn.Parameter(torch.empty(1,1,C))
290
- self.x_g = nn.Parameter(torch.empty(1,1,C))
291
-
292
- self.w0 = nn.Parameter(torch.empty(1,1,C))
293
- self.w1 = nn.Parameter(torch.empty(C, lora_rank_decay))
294
- self.w2 = nn.Parameter(torch.empty(lora_rank_decay, C))
295
-
296
- self.a0 = nn.Parameter(torch.empty(1,1,C))
297
- self.a1 = nn.Parameter(torch.empty(C, lora_rank_iclr))
298
- self.a2 = nn.Parameter(torch.empty(lora_rank_iclr, C))
299
-
300
- if layer_id > 0:
301
- self.v0 = nn.Parameter(torch.empty(1,1,C))
302
- self.v1 = nn.Parameter(torch.empty(C, lora_rank_value_residual_mix))
303
- self.v2 = nn.Parameter(torch.empty(lora_rank_value_residual_mix, C))
304
-
305
- self.g1 = nn.Parameter(torch.empty(C, lora_rank_gate))
306
- self.g2 = nn.Parameter(torch.empty(lora_rank_gate, C))
307
-
308
- self.k_k = nn.Parameter(torch.empty(1,1,C))
309
- self.k_a = nn.Parameter(torch.empty(1,1,C))
310
- self.r_k = nn.Parameter(torch.empty(H,N))
311
-
312
- self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
313
- self.receptance = nn.Linear(C, C, bias=False)
314
- self.key = nn.Linear(C, C, bias=False)
315
- self.value = nn.Linear(C, C, bias=False)
316
- self.output = nn.Linear(C, C, bias=False)
317
- self.ln_x = nn.GroupNorm(H, C, eps=self.head_size * 1e-5)
318
-
319
-
320
- def forward(self, hidden, state=None, v_first=None, use_cache=False, seq_mode=True):
321
- # Mix hidden with the previous timestep to produce key, value, receptance
322
- if hidden.size(1) == 1 and state is not None:
323
- shifted = state[0][self.layer_id]
324
- else:
325
- shifted = self.time_shift(hidden)
326
- if state is not None:
327
- shifted[:, 0] = state[0][self.layer_id]
328
- if len(shifted.size()) == 2:
329
- shifted = shifted.unsqueeze(1)
330
-
331
- x = hidden
332
-
333
- B, T, C = hidden.shape
334
- H = self.num_heads
335
- N = self.head_size
336
-
337
- xx = shifted - x
338
-
339
- xr = x+xx*self.x_r
340
- xw = x+xx*self.x_w
341
- xk = x+xx*self.x_k
342
- xv = x+xx*self.x_v
343
- xa = x+xx*self.x_a
344
- xg = x+xx*self.x_g
345
-
346
- r = self.receptance(xr)
347
- w = torch.tanh(xw @ self.w1) @ self.w2
348
- k = self.key(xk)
349
- v = self.value(xv)
350
- a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2)
351
- g = torch.sigmoid(xg @ self.g1) @ self.g2
352
-
353
- kk = torch.nn.functional.normalize((k * self.k_k).view(B,T,H,-1), dim=-1, p=2.0).view(B,T,-1)
354
- k = k * (1 + (a-1) * self.k_a)
355
- if self.layer_id == 0: v_first = v
356
- else: v = v + (v_first - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2)
357
-
358
- if T == 1 or not self.training:
359
- w = torch.exp(-0.606531 * torch.sigmoid((self.w0 + w).float())) # 0.606531 = exp(-0.5)
360
- vk_state = state[1][self.layer_id]
361
- for t in range(T):
362
- r_, w_, k_, v_, kk_, a_ = r[:,t], w[:,t], k[:,t], v[:,t], kk[:,t], a[:,t]
363
- vk = v_.view(B,H,N,1) @ k_.view(B,H,1,N)
364
- ab = (-kk_).view(B,H,N,1) @ (kk_*a_).view(B,H,1,N)
365
- vk_state = vk_state * w_.view(B,H,1,N) + vk_state @ ab.float() + vk.float()
366
- xx[:,t] = (vk_state.to(dtype=x.dtype) @ r_.view(B,H,N,1)).view(B,H*N)
367
- state[1][self.layer_id] = vk_state
368
- # FIXME - support fast triton kernel for non-training pre-fill with state in and out
369
- else:
370
- w = -torch.nn.functional.softplus(-(self.w0 + w)) - 0.5
371
- rwkv7_attn_triton(r, w, k, v, -kk, kk*a, self.head_size)
372
-
373
- xx = torch.nn.functional.group_norm(xx.view(B*T,H*N), num_groups=H, weight=self.ln_x.weight, bias=self.ln_x.bias, eps = self.ln_x.eps).view(B,T,H*N)
374
- xx = xx + ((r.view(B,T,H,-1)*k.view(B,T,H,-1)*self.r_k).sum(dim=-1, keepdim=True) * v.view(B,T,H,-1)).view(B,T,C)
375
- xx = self.output(xx * g)
376
-
377
- if state is not None:
378
- state[0][self.layer_id] = hidden[:, -1]
379
-
380
- return xx, state, v_first
381
-
382
-
383
- class Rwkv7FeedForward(nn.Module):
384
- def __init__(self, config, layer_id=0):
385
- super().__init__()
386
- self.config = config
387
- self.layer_id = layer_id
388
- hidden_size = config.hidden_size
389
- intermediate_size = (
390
- config.intermediate_size
391
- if config.intermediate_size is not None
392
- else int(config.hidden_size * 4)
393
- )
394
-
395
-
396
- self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
397
-
398
- self.x_k = nn.Parameter(torch.empty(1, 1, hidden_size))
399
-
400
- self.key = nn.Linear(hidden_size, intermediate_size, bias=False)
401
- self.value = nn.Linear(intermediate_size, hidden_size, bias=False)
402
-
403
- def forward(self, hidden, state=None):
404
- if hidden.size(1) == 1 and state is not None:
405
- shifted = state[2][self.layer_id]
406
- else:
407
- shifted = self.time_shift(hidden)
408
- if state is not None:
409
- shifted[:, 0] = state[2][self.layer_id]
410
- if len(shifted.size()) == 2:
411
- shifted = shifted.unsqueeze(1)
412
-
413
- delta_hidden_to_shifted = shifted - hidden
414
- key = hidden + delta_hidden_to_shifted * self.x_k
415
-
416
- key = torch.square(torch.relu(self.key(key)))
417
- value = self.value(key)
418
-
419
- if state is not None:
420
- state[2][self.layer_id] = hidden[:, -1]
421
-
422
- return value, state
423
-
424
-
425
- class Rwkv7Block(nn.Module):
426
- def __init__(self, config, layer_id):
427
- super().__init__()
428
- self.config = config
429
- self.layer_id = layer_id
430
-
431
- self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
432
- self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
433
-
434
- self.attention = Rwkv7SelfAttention(config, layer_id)
435
- self.feed_forward = Rwkv7FeedForward(config, layer_id)
436
-
437
- def forward(self, hidden, state=None, v_first=None, use_cache=False, output_attentions=False, seq_mode=True):
438
- attention, state, v_first = self.attention(self.ln1(hidden), state=state, v_first=v_first, use_cache=use_cache, seq_mode=seq_mode)
439
- hidden = hidden + attention
440
-
441
- feed_forward, state = self.feed_forward(self.ln2(hidden), state=state)
442
- hidden = hidden + feed_forward
443
-
444
- outputs = (hidden, state, v_first)
445
- if output_attentions:
446
- outputs += (attention,)
447
- else:
448
- outputs += (None,)
449
-
450
- return outputs
451
-
452
-
453
- class Rwkv7PreTrainedModel(PreTrainedModel):
454
- """
455
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
456
- models.
457
- """
458
-
459
- config_class = Rwkv7Config
460
- base_model_prefix = "rwkv7"
461
- _no_split_modules = ["Rwkv7Block"]
462
- _keep_in_fp32_modules = []
463
- supports_gradient_checkpointing = True
464
-
465
- def _init_weights(self, module):
466
- return
467
-
468
- """Initialize the weights."""
469
- if isinstance(module, Rwkv7SelfAttention):
470
- layer_id = module.layer_id
471
- num_hidden_layers = module.config.num_hidden_layers
472
- hidden_size = module.config.hidden_size
473
- attention_hidden_size = module.attention_hidden_size
474
- head_size = module.config.head_size
475
- num_heads = attention_hidden_size // head_size
476
-
477
- ratio_0_to_1 = layer_id / (num_hidden_layers - 1) # 0 to 1
478
- ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0
479
-
480
- time_weight = torch.tensor(
481
- [i / hidden_size for i in range(hidden_size)],
482
- dtype=module.x_k.dtype,
483
- device=module.x_k.device,
484
- )
485
- time_weight = time_weight[None, None, :]
486
-
487
- decay_speed = [
488
- -7.0 + 5.0 * (n / (attention_hidden_size - 1)) ** (0.85 + 1.0 * ratio_0_to_1 ** 0.5)
489
- for n in range(attention_hidden_size)
490
- ]
491
- decay_speed = torch.tensor(decay_speed, dtype=module.w0.dtype, device=module.w0.device)
492
-
493
- with torch.no_grad():
494
- module.x_r.copy_( 1.0 - torch.pow(time_weight, 0.2 * ratio_1_to_almost0) )
495
- module.x_w.copy_( 1.0 - torch.pow(time_weight, 0.9 * ratio_1_to_almost0) )
496
- module.x_k.copy_( 1.0 - (torch.pow(time_weight, 0.9 * ratio_1_to_almost0) + 0.4 * ratio_0_to_1) )
497
- module.x_v.copy_( 1.0 - (torch.pow(time_weight, 0.4 * ratio_1_to_almost0) + 0.6 * ratio_0_to_1) )
498
- module.x_a.copy_( 1.0 - torch.pow(time_weight, 0.9 * ratio_1_to_almost0) )
499
- module.x_g.copy_( 1.0 - torch.pow(time_weight, 0.2 * ratio_1_to_almost0) )
500
-
501
- def ortho_init(x, scale):
502
- with torch.no_grad():
503
- shape = x.shape
504
- if len(shape) == 2:
505
- gain = math.sqrt(shape[0] / shape[1]) if shape[0] > shape[1] else 1
506
- nn.init.orthogonal_(x, gain=gain * scale)
507
- elif len(shape) == 3:
508
- gain = math.sqrt(shape[1] / shape[2]) if shape[1] > shape[2] else 1
509
- for i in range(shape[0]):
510
- nn.init.orthogonal_(x[i], gain=gain * scale)
511
- else:
512
- assert False
513
- return x
514
-
515
- module.w0.copy_(decay_speed.reshape(1,1,attention_hidden_size) + 0.5) # !!! 0.5 comes from F.softplus !!!
516
- module.w1.zero_()
517
- ortho_init(module.w2, 0.1)
518
-
519
- module.a0.zero_()
520
- module.a1.zero_()
521
- ortho_init(module.a2, 0.1)
522
-
523
- module.v0.copy_(1.0)
524
- module.v1.zero_()
525
- ortho_init(module.v2, 0.1)
526
-
527
- module.g1.zero_()
528
- ortho_init(module.g2, 0.1)
529
-
530
- self.k_k.copy_(0.85)
531
- self.k_a.copy_(1.0)
532
- self.r_k.zero_()
533
-
534
- module.receptance.weight.data.uniform_(-0.5/(hidden_size**0.5), 0.5/(attention_hidden_size**0.5))
535
- module.key.weight.data.uniform_(-0.05/(hidden_size**0.5), 0.05/(attention_hidden_size**0.5))
536
- module.value.weight.data.uniform_(-0.5/(hidden_size**0.5), 0.5/(attention_hidden_size**0.5))
537
- module.output.weight.data.zero_()
538
-
539
- elif isinstance(module, Rwkv7FeedForward):
540
- layer_id = module.layer_id
541
- num_hidden_layers = module.config.num_hidden_layers
542
- hidden_size = module.config.hidden_size
543
-
544
- ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0
545
-
546
- time_weight = torch.tensor(
547
- [i / hidden_size for i in range(hidden_size)],
548
- dtype=module.x_k.dtype,
549
- device=module.x_k.device,
550
- )
551
- time_weight = time_weight[None, None, :]
552
-
553
- with torch.no_grad():
554
- module.x_k.copy_( 1.0 - torch.pow(time_weight, ratio_1_to_almost0**4) )
555
-
556
- self.key.weight.data.uniform_(-0.5/(hidden_size**0.5), 0.5/(hidden_size**0.5))
557
- self.value.weight.data.zero_()
558
-
559
- @dataclass
560
- class Rwkv7Output(ModelOutput):
561
- """
562
- Class for the RWKV model outputs.
563
- Args:
564
- last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
565
- Sequence of hidden-states at the output of the last layer of the model.
566
- state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
567
- The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
568
- avoid providing the old `input_ids`.
569
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
570
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
571
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
572
- the model at the output of each layer plus the optional initial embedding outputs.
573
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
574
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
575
- sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
576
- the self-attention heads.
577
- """
578
-
579
- last_hidden_state: torch.FloatTensor = None
580
- state: Optional[List[torch.FloatTensor]] = None
581
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
582
- attentions: Optional[Tuple[torch.FloatTensor]] = None
583
-
584
-
585
- @dataclass
586
- class Rwkv7CausalLMOutput(ModelOutput):
587
- """
588
- Base class for causal language model (or autoregressive) outputs.
589
- Args:
590
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
591
- Language modeling loss (for next-token prediction).
592
- logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
593
- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
594
- state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
595
- The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
596
- avoid providing the old `input_ids`.
597
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
598
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
599
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
600
- the model at the output of each layer plus the optional initial embedding outputs.
601
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
602
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
603
- sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
604
- the self-attention heads.
605
- """
606
-
607
- loss: Optional[torch.FloatTensor] = None
608
- logits: torch.FloatTensor = None
609
- state: Optional[List[torch.FloatTensor]] = None
610
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
611
- attentions: Optional[Tuple[torch.FloatTensor]] = None
612
-
613
-
614
- RWKV7_START_DOCSTRING = r"""
615
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
616
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
617
- etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)
618
- subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
619
- general usage and behavior.
620
- Parameters:
621
- config ([`Rwkv7Config`]): Model configuration class with all the parameters of the model.
622
- Initializing with a config file does not load the weights associated with the model, only the
623
- configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
624
- """
625
-
626
- RWKV7_INPUTS_DOCSTRING = r"""
627
- Args:
628
- input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
629
- `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
630
- `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
631
- sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their
632
- past calculated should be passed as `input_ids`. Indices can be obtained using [`AutoTokenizer`]. See
633
- [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
634
- IDs?](../glossary#input-ids)
635
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
636
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
637
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
638
- model's internal embedding lookup matrix.
639
- state (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*):
640
- If passed along, the model uses the previous state in all the blocks (which will give the output for the
641
- `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
642
- use_cache (`bool`, *optional*):
643
- If set to `True`, the last state is returned and can be used to quickly generate the next logits.
644
- output_attentions (`bool`, *optional*):
645
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
646
- tensors for more detail.
647
- output_hidden_states (`bool`, *optional*):
648
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
649
- more detail.
650
- return_dict (`bool`, *optional*):
651
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
652
- """
653
-
654
-
655
- @add_start_docstrings(
656
- "The bare RWKV7 Model transformer outputting raw hidden-states without any specific head on top.",
657
- RWKV7_START_DOCSTRING,
658
- )
659
- class Rwkv7Model(Rwkv7PreTrainedModel):
660
- def __init__(self, config):
661
- super().__init__(config)
662
-
663
- self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
664
- self.pre_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
665
- self.blocks = nn.ModuleList([Rwkv7Block(config, layer_id=idx) for idx in range(config.num_hidden_layers)])
666
- self.ln_out = nn.LayerNorm(config.hidden_size)
667
-
668
- self.gradient_checkpointing = False
669
-
670
- # Initialize weights and apply final processing
671
- self.post_init()
672
-
673
- def get_input_embeddings(self):
674
- return self.embeddings
675
-
676
- def set_input_embeddings(self, new_embeddings):
677
- self.embeddings = new_embeddings
678
-
679
- @add_start_docstrings_to_model_forward(RWKV7_INPUTS_DOCSTRING)
680
- @add_code_sample_docstrings(
681
- checkpoint=_CHECKPOINT_FOR_DOC,
682
- output_type=Rwkv7Output,
683
- config_class=_CONFIG_FOR_DOC,
684
- )
685
- def forward(
686
- self,
687
- input_ids: Optional[torch.LongTensor] = None,
688
- attention_mask: Optional[torch.LongTensor] = None, # noqa
689
- inputs_embeds: Optional[torch.FloatTensor] = None,
690
- state: Optional[List[torch.FloatTensor]] = None,
691
- use_cache: Optional[bool] = None,
692
- output_attentions: Optional[bool] = None,
693
- output_hidden_states: Optional[bool] = None,
694
- return_dict: Optional[bool] = None,
695
- ) -> Union[Tuple, Rwkv7Output]:
696
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
697
- output_hidden_states = (
698
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
699
- )
700
- use_cache = use_cache if use_cache is not None else self.config.use_cache
701
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
702
-
703
- if input_ids is not None and inputs_embeds is not None:
704
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
705
- elif input_ids is None and inputs_embeds is None:
706
- raise ValueError("You have to specify either input_ids or inputs_embeds")
707
-
708
- if inputs_embeds is None:
709
- inputs_embeds = self.embeddings(input_ids)
710
-
711
- if state is None:
712
- state = []
713
- head_size = self.config.head_size
714
- num_heads = self.config.attention_hidden_size // head_size
715
- state_attn_x = torch.zeros(
716
- (self.config.num_hidden_layers, inputs_embeds.size(0), self.config.hidden_size),
717
- dtype=inputs_embeds.dtype,
718
- requires_grad=False,
719
- device=inputs_embeds.device,
720
- ).contiguous()
721
- state_attn_vk = torch.zeros(
722
- (
723
- self.config.num_hidden_layers,
724
- inputs_embeds.size(0),
725
- num_heads,
726
- head_size,
727
- head_size,
728
- ),
729
- dtype=torch.float32,
730
- requires_grad=False,
731
- device=inputs_embeds.device,
732
- ).contiguous()
733
- state_ffn_x = torch.zeros(
734
- (self.config.num_hidden_layers, inputs_embeds.size(0), self.config.hidden_size),
735
- dtype=inputs_embeds.dtype,
736
- requires_grad=False,
737
- device=inputs_embeds.device,
738
- ).contiguous()
739
- state.append(state_attn_x)
740
- state.append(state_attn_vk)
741
- state.append(state_ffn_x)
742
-
743
- seq_mode = inputs_embeds.shape[1] > 1
744
- hidden_states = self.pre_ln(inputs_embeds)
745
- v_first = None
746
-
747
- all_self_attentions = () if output_attentions else None
748
- all_hidden_states = () if output_hidden_states else None
749
- for idx, block in enumerate(self.blocks):
750
- hidden_states, state, v_first, attentions = block(
751
- hidden_states, state=state, v_first=v_first, use_cache=use_cache, output_attentions=output_attentions, seq_mode=seq_mode
752
- )
753
-
754
- if output_hidden_states:
755
- all_hidden_states = all_hidden_states + (hidden_states,)
756
-
757
- if output_attentions:
758
- all_self_attentions = all_self_attentions + (attentions,)
759
-
760
- hidden_states = self.ln_out(hidden_states)
761
-
762
- if output_hidden_states:
763
- all_hidden_states = all_hidden_states + (hidden_states,)
764
-
765
- if not return_dict:
766
- return (hidden_states, state, all_hidden_states, all_self_attentions)
767
-
768
- return Rwkv7Output(
769
- last_hidden_state=hidden_states,
770
- state=state,
771
- hidden_states=all_hidden_states, # None
772
- attentions=all_self_attentions, # None
773
- )
774
-
775
- # copied from HuggingFace https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py
776
- @add_start_docstrings(
777
- """
778
- The RWKV7 Model transformer with a language modeling head on top (linear layer with weights tied to the input
779
- embeddings).
780
- """,
781
- RWKV7_START_DOCSTRING,
782
- )
783
- class Rwkv7ForCausalLM(Rwkv7PreTrainedModel, GenerationMixin):
784
- _tied_weights_keys = ["head.weight"]
785
-
786
- def __init__(self, config):
787
- super().__init__(config)
788
- self.model = Rwkv7Model(config)
789
- self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
790
-
791
- # Initialize weights and apply final processing
792
- self.post_init()
793
-
794
- def get_output_embeddings(self):
795
- return self.head
796
-
797
- def set_output_embeddings(self, new_embeddings):
798
- self.head = new_embeddings
799
-
800
- def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, **kwargs):
801
- # only last token for inputs_ids if the state is passed along.
802
- if state is not None:
803
- input_ids = input_ids[:, -1].unsqueeze(-1)
804
-
805
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
806
- if inputs_embeds is not None and state is None:
807
- model_inputs = {"inputs_embeds": inputs_embeds}
808
- else:
809
- model_inputs = {"input_ids": input_ids}
810
-
811
- model_inputs["state"] = state
812
- return model_inputs
813
-
814
- @add_start_docstrings_to_model_forward(RWKV7_INPUTS_DOCSTRING)
815
- @add_code_sample_docstrings(
816
- checkpoint=_CHECKPOINT_FOR_DOC,
817
- output_type=Rwkv7CausalLMOutput,
818
- config_class=_CONFIG_FOR_DOC,
819
- )
820
- def forward(
821
- self,
822
- input_ids: Optional[torch.LongTensor] = None,
823
- attention_mask: Optional[torch.LongTensor] = None,
824
- inputs_embeds: Optional[torch.FloatTensor] = None,
825
- state: Optional[List[torch.FloatTensor]] = None,
826
- labels: Optional[torch.LongTensor] = None,
827
- use_cache: Optional[bool] = None,
828
- output_attentions: Optional[bool] = None,
829
- output_hidden_states: Optional[bool] = None,
830
- return_dict: Optional[bool] = None,
831
- ) -> Union[Tuple, Rwkv7CausalLMOutput]:
832
- r"""
833
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
834
- Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
835
- `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
836
- are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
837
- """
838
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
839
-
840
- outputs = self.model(
841
- input_ids,
842
- inputs_embeds=inputs_embeds,
843
- state=state,
844
- use_cache=use_cache,
845
- output_attentions=output_attentions,
846
- output_hidden_states=output_hidden_states,
847
- return_dict=return_dict,
848
- )
849
- hidden_states = outputs[0]
850
-
851
- logits = self.head(hidden_states)
852
-
853
- loss = None
854
- if labels is not None:
855
- # move labels to correct device to enable model parallelism
856
- labels = labels.to(logits.device)
857
- # Shift so that tokens < n predict n
858
- shift_logits = logits[..., :-1, :].contiguous()
859
- shift_labels = labels[..., 1:].contiguous()
860
- # Flatten the tokens
861
- loss_fct = CrossEntropyLoss()
862
- loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
863
-
864
- if not return_dict:
865
- output = (logits,) + outputs[1:]
866
- return ((loss,) + output) if loss is not None else output
867
-
868
- return Rwkv7CausalLMOutput(
869
- loss=loss,
870
- logits=logits,
871
- state=outputs.state,
872
- hidden_states=outputs.hidden_states,
873
- attentions=outputs.attentions,
874
- )
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The RWKV team and HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch RWKV7 World model."""
16
+
17
+ from dataclasses import dataclass
18
+ from typing import List, Optional, Tuple, Union
19
+
20
+ from pathlib import Path
21
+
22
+ import math
23
+ import torch
24
+ import torch.nn.functional as F
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+ from torch.nn import CrossEntropyLoss
28
+
29
+ from transformers.modeling_utils import PreTrainedModel, GenerationMixin, _init_weights
30
+ from transformers.utils import (
31
+ ModelOutput,
32
+ add_code_sample_docstrings,
33
+ add_start_docstrings,
34
+ add_start_docstrings_to_model_forward,
35
+ is_ninja_available,
36
+ is_torch_cuda_available,
37
+ logging,
38
+ )
39
+
40
+ from .configuration_rwkv7 import Rwkv7Config
41
+
42
+ # MIT License
43
+
44
+ # Copyright (c) 2024 Songlin Yang
45
+
46
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
47
+ # of this software and associated documentation files (the "Software"), to deal
48
+ # in the Software without restriction, including without limitation the rights
49
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
50
+ # copies of the Software, and to permit persons to whom the Software is
51
+ # furnished to do so, subject to the following conditions:
52
+
53
+ # The above copyright notice and this permission notice shall be included in all
54
+ # copies or substantial portions of the Software.
55
+
56
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
57
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
58
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
59
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
60
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
61
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
62
+ # SOFTWARE.
63
+
64
+ # Copyright (c) 2024, Johan Sokrates Wind
65
+
66
+ import torch as th
67
+ import triton
68
+ import triton.language as tl
69
+
70
+ @triton.jit
71
+ def IND4(a,b,c,d,nb,nc,nd):
72
+ return ((a*nb+b)*nc+c)*nd+d
73
+ @triton.jit
74
+ def IND5(a,b,c,d,e,nb,nc,nd,ne):
75
+ return (((a*nb+b)*nc+c)*nd+d)*ne+e
76
+
77
+ @triton.jit
78
+ def _prod(a,b): return a*b
79
+
80
+ # inv(I-A) where A is a strictly lower triangular nxn matrix
81
+ @triton.jit
82
+ def tri_minv(A, n:tl.constexpr, prec:tl.constexpr):
83
+ i = tl.arange(0,n)
84
+ prod = (i[None,:]==i[:,None]).to(tl.float32)
85
+ for j in range(n-1):
86
+ prod += tl_dot(prec, prod, (A*((i[None,:]==j)*(i[:,None]>i[None,:]))).trans())
87
+ return prod.trans()
88
+
89
+ @triton.jit
90
+ def fw_attn_triton(w_,q_,k_,v_,a_,b_, s0_,y_,s_,sT_, B:tl.constexpr,T:tl.constexpr,H:tl.constexpr,C:tl.constexpr,dT:tl.constexpr, prec:tl.constexpr):
91
+ bi = tl.program_id(1)
92
+ hi = tl.program_id(0)
93
+
94
+ i = tl.arange(0,C)[None,:]
95
+ state = tl.load(s0_+IND4(bi,hi,i.trans(),i, H,C,C)).to(tl.float32)
96
+ for t0 in range(T//dT):
97
+ t = t0*dT+tl.arange(0,dT)[:,None]
98
+ sw = tl.load(w_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
99
+ sq = tl.load(q_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
100
+ sk = tl.load(k_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
101
+ sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
102
+ sa = tl.load(a_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
103
+ sb = tl.load(b_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
104
+
105
+ w = (-sw.exp()).exp()
106
+ fw = tl.reduce(w, 0, _prod, keep_dims=True)
107
+ incl_pref = tl.cumprod(w,axis=0)
108
+ non_incl_pref = incl_pref / w
109
+ inv_incl_pref = 1 / incl_pref
110
+
111
+ wq = sq * incl_pref
112
+ wa = sa * non_incl_pref
113
+ kwi = sk * inv_incl_pref
114
+ bwi = sb * inv_incl_pref
115
+
116
+ mask1 = (t > t.trans())
117
+ ab = tl_dot(prec, wa, bwi.trans()) * mask1
118
+ ak = tl_dot(prec, wa, kwi.trans()) * mask1
119
+
120
+ ab_inv = tri_minv(ab, dT, prec)
121
+
122
+ ab_u = tl_dot(prec, ak, sv) + tl_dot(prec, wa, state.trans())
123
+ u = tl_dot(prec, ab_inv, ab_u)
124
+ mask2 = (t >= t.trans())
125
+ qk = tl_dot(prec, wq, kwi.trans()) * mask2
126
+ qb = tl_dot(prec, wq, bwi.trans()) * mask2
127
+ yy = tl_dot(prec, qk, sv) + tl_dot(prec, qb, u) + tl_dot(prec, wq, state.trans())
128
+ tl.store(y_+IND4(bi,t,hi,i, T,H,C), yy.to(tl.bfloat16))
129
+
130
+ tl.store(s_+IND5(bi,hi,t0,i.trans(),i, H,T//dT,C,C), state.to(tl.float32))
131
+ state = state * fw + tl_dot(prec, sv.trans(), kwi*fw) + tl_dot(prec, u.trans(), bwi*fw)
132
+ tl.store(sT_+IND4(bi,hi,i.trans(),i, H,C,C), state.to(tl.bfloat16))
133
+
134
+ @triton.jit
135
+ def bw_attn_triton(w_,q_,k_,v_,a_,b_, dy_,s_,dsT_, dw_,dq_,dk_,dv_,da_,db_,ds0_, B:tl.constexpr,T:tl.constexpr,H:tl.constexpr,C:tl.constexpr,dT:tl.constexpr, prec:tl.constexpr):
136
+ bi = tl.program_id(1)
137
+ hi = tl.program_id(0)
138
+
139
+ i = tl.arange(0,C)[None,:]
140
+ dstate = tl.load(dsT_+IND4(bi,hi,i.trans(),i, H,C,C)).to(tl.float32)
141
+
142
+ for t0 in range(T//dT-1,-1,-1):
143
+ t = t0*dT+tl.arange(0,dT)[:,None]
144
+
145
+ state = tl.load(s_+IND5(bi,hi,t0,i.trans(),i, H,T//dT,C,C)).to(tl.float32)
146
+
147
+ sw = tl.load(w_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
148
+ sq = tl.load(q_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
149
+ sk = tl.load(k_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
150
+ sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
151
+ sa = tl.load(a_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
152
+ sb = tl.load(b_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
153
+ sdy = tl.load(dy_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
154
+
155
+ dw_fac = -sw.exp()
156
+ w = dw_fac.exp()
157
+ fw = tl.reduce(w, 0, _prod, keep_dims=True)
158
+ incl_pref = tl.cumprod(w,axis=0)
159
+ non_incl_pref = incl_pref / w
160
+ inv_incl_pref = 1 / incl_pref
161
+
162
+ wq = sq * incl_pref
163
+ wa = sa * non_incl_pref
164
+ kwi = sk * inv_incl_pref
165
+ bwi = sb * inv_incl_pref
166
+
167
+ mask1 = (t > t.trans())
168
+ ab = tl_dot(prec, wa, bwi.trans()) * mask1
169
+ ak = tl_dot(prec, wa, kwi.trans()) * mask1
170
+
171
+ ab_inv = tri_minv(ab, dT, prec)
172
+
173
+ ab_u = tl_dot(prec, ak, sv) + tl_dot(prec, wa, state.trans())
174
+ u = tl_dot(prec, ab_inv, ab_u)
175
+ mask2 = (t >= t.trans())
176
+ qk = tl_dot(prec, wq, kwi.trans()) * mask2
177
+ qb = tl_dot(prec, wq, bwi.trans()) * mask2
178
+
179
+ du = tl_dot(prec, qb.trans(), sdy) + tl_dot(prec, bwi*fw, dstate.trans())
180
+ dab_u = tl_dot(prec, ab_inv.trans(), du)
181
+
182
+ dv = tl_dot(prec, qk.trans(), sdy) + tl_dot(prec, kwi*fw, dstate.trans()) + tl_dot(prec, ak.trans(), dab_u)
183
+ tl.store(dv_+IND4(bi,t,hi,i, T,H,C), dv.to(tl.bfloat16))
184
+
185
+ dab = tl_dot(prec, tl_dot(prec, ab_inv.trans(), du), u.trans()) * mask1
186
+ dak = tl_dot(prec, dab_u, sv.trans()) * mask1
187
+ dab_u_state = tl_dot(prec, dab_u, state)
188
+ da = non_incl_pref * (tl_dot(prec, dab, bwi) + tl_dot(prec, dak, kwi) + dab_u_state)
189
+ tl.store(da_+IND4(bi,t,hi,i, T,H,C), da.to(tl.bfloat16))
190
+
191
+ dqb = tl_dot(prec, sdy, u.trans()) * mask2
192
+ dqk = tl_dot(prec, sdy, sv.trans()) * mask2
193
+ dy_state = tl_dot(prec, sdy, state)
194
+ dq = incl_pref * (tl_dot(prec, dqb, bwi) + tl_dot(prec, dqk, kwi) + dy_state)
195
+ tl.store(dq_+IND4(bi,t,hi,i, T,H,C), dq.to(tl.bfloat16))
196
+
197
+ fw_u_dstate = fw * tl_dot(prec, u, dstate)
198
+ db = inv_incl_pref * (tl_dot(prec, dab.trans(), wa) + tl_dot(prec, dqb.trans(), wq) + fw_u_dstate)
199
+ tl.store(db_+IND4(bi,t,hi,i, T,H,C), db.to(tl.bfloat16))
200
+
201
+ fw_v_dstate = fw * tl_dot(prec, sv, dstate)
202
+ dk = inv_incl_pref * (tl_dot(prec, dak.trans(), wa) + tl_dot(prec, dqk.trans(), wq) + fw_v_dstate)
203
+ tl.store(dk_+IND4(bi,t,hi,i, T,H,C), dk.to(tl.bfloat16))
204
+
205
+ dw0 = fw * tl.sum(state*dstate, axis=0,keep_dims=True)
206
+ for k in range(t0*dT,t0*dT+dT):
207
+ lmask = (t<k).trans()
208
+ A = (tl_dot(prec, dab*lmask, bwi) + tl_dot(prec, dak*lmask, kwi)) * wa * (t>k)
209
+ A += (tl_dot(prec, dqb*lmask, bwi) + tl_dot(prec, dqk*lmask, kwi)) * wq * (t>=k)
210
+ A += (fw_v_dstate*kwi + fw_u_dstate*bwi) * (t<k)
211
+ A += dab_u_state*wa * (t>k) + dy_state*wq * (t>=k)
212
+ dw = tl.sum(A, axis=0,keep_dims=True) + dw0
213
+
214
+ wk = tl.load(w_+IND4(bi,k,hi,i, T,H,C)).to(tl.float32)
215
+ dw *= -wk.exp()
216
+ tl.store(dw_+IND4(bi,k,hi,i, T,H,C), dw.to(tl.bfloat16))
217
+
218
+ dstate = dstate * fw + tl_dot(prec, sdy.trans(), wq) + tl_dot(prec, dab_u.trans(), wa)
219
+ tl.store(ds0_+IND4(bi,hi,i.trans(),i, H,C,C), dstate.to(tl.bfloat16))
220
+
221
+
222
+ class TritonRWKV7(th.autograd.Function):
223
+ @staticmethod
224
+ def forward(ctx, w,q,k,v,z,b,s0, dot_prec):
225
+ K = 16
226
+ B,T,H,C = w.shape
227
+ s0 = th.zeros(B,H,C,C, dtype=w.dtype,device=w.device) if s0 is None else s0
228
+ y = th.empty_like(v)
229
+ sT = th.empty_like(s0)
230
+ s = th.zeros(B,H,T//K,C,C, dtype=th.float32,device=w.device)
231
+ fw_attn_triton[(H,B)](w,q,k,v,z,b, s0,y,s,sT, B,T,H,C,K, dot_prec)
232
+ ctx.dot_prec = dot_prec
233
+ ctx.save_for_backward(w,q,k,v,z,b,s)
234
+ return y, sT
235
+ @staticmethod
236
+ def backward(ctx, dy, dsT):
237
+ K = 16
238
+ w,q,k,v,z,b,s = ctx.saved_tensors
239
+ B,T,H,C = w.shape
240
+ dw,dq,dk,dv,dz,db,ds0 = [th.empty_like(x) for x in [w,q,k,v,z,b,dsT]]
241
+ bw_attn_triton[(H,B)](w,q,k,v,z,b, dy,s,dsT, dw,dq,dk,dv,dz,db,ds0, B,T,H,C,K, ctx.dot_prec)
242
+ return dw,dq,dk,dv,dz,db,ds0,None
243
+
244
+ @triton.jit
245
+ def tl_dot(prec:tl.constexpr, a, b) -> torch.Tensor:
246
+ if prec == 'fp32':
247
+ return tl.dot(a.to(tl.float32),b.trans().to(tl.float32).trans(), allow_tf32=False)
248
+ elif prec == 'tf32':
249
+ return tl.dot(a.to(tl.float32),b.trans().to(tl.float32).trans(), allow_tf32=True)
250
+ elif prec == 'bf16':
251
+ return tl.dot(a.to(tl.bfloat16),b.trans().to(tl.bfloat16).trans(), allow_tf32=True)
252
+ else:
253
+ tl.static_assert(False)
254
+
255
+ def rwkv7_attn_triton(r,w,k,v,a,b, HEAD_SIZE, dot_prec = 'fp32'):
256
+ B,T,HC = w.shape
257
+ C = HEAD_SIZE
258
+ H = HC//C
259
+ r,w,k,v,a,b = [i.view(B,T,H,C) for i in [r,w,k,v,a,b]]
260
+ s0 = th.zeros(B,H,C,C, dtype=th.bfloat16,device=w.device)
261
+ return TritonRWKV7.apply(w,r,k,v,a,b,s0,dot_prec)[0].view(B,T,HC)
262
+
263
+ logger = logging.get_logger(__name__)
264
+
265
+ _CHECKPOINT_FOR_DOC = "RWKV/v7-Goose-1.6B-Pile-HF"
266
+ _CONFIG_FOR_DOC = "Rwkv7Config"
267
+
268
+ class Rwkv7SelfAttention(nn.Module):
269
+ def __init__(self, config, layer_id=0):
270
+ super().__init__()
271
+ self.config = config
272
+ self.layer_id = layer_id
273
+ C = hidden_size = config.hidden_size
274
+ attention_hidden_size = config.attention_hidden_size
275
+ self.attention_hidden_size = attention_hidden_size
276
+ H = self.num_heads = attention_hidden_size // config.head_size
277
+ N = self.head_size = config.head_size
278
+
279
+ calc_lora_rank = lambda exponent, multiplier: max(1, round(hidden_size ** exponent * multiplier / 32)) * 32
280
+ lora_rank_decay = config.lora_rank_decay or calc_lora_rank(0.5, 1.8)
281
+ lora_rank_iclr = config.lora_rank_iclr or calc_lora_rank(0.5, 1.8)
282
+ lora_rank_value_residual_mix = config.lora_rank_value_residual_mix or calc_lora_rank(0.5, 1.3)
283
+ lora_rank_gate = config.lora_rank_gate or calc_lora_rank(0.8, 0.6)
284
+
285
+ self.x_r = nn.Parameter(torch.empty(1,1,C))
286
+ self.x_w = nn.Parameter(torch.empty(1,1,C))
287
+ self.x_k = nn.Parameter(torch.empty(1,1,C))
288
+ self.x_v = nn.Parameter(torch.empty(1,1,C))
289
+ self.x_a = nn.Parameter(torch.empty(1,1,C))
290
+ self.x_g = nn.Parameter(torch.empty(1,1,C))
291
+
292
+ self.w0 = nn.Parameter(torch.empty(1,1,C))
293
+ self.w1 = nn.Parameter(torch.empty(C, lora_rank_decay))
294
+ self.w2 = nn.Parameter(torch.empty(lora_rank_decay, C))
295
+
296
+ self.a0 = nn.Parameter(torch.empty(1,1,C))
297
+ self.a1 = nn.Parameter(torch.empty(C, lora_rank_iclr))
298
+ self.a2 = nn.Parameter(torch.empty(lora_rank_iclr, C))
299
+
300
+ if layer_id > 0:
301
+ self.v0 = nn.Parameter(torch.empty(1,1,C))
302
+ self.v1 = nn.Parameter(torch.empty(C, lora_rank_value_residual_mix))
303
+ self.v2 = nn.Parameter(torch.empty(lora_rank_value_residual_mix, C))
304
+
305
+ self.g1 = nn.Parameter(torch.empty(C, lora_rank_gate))
306
+ self.g2 = nn.Parameter(torch.empty(lora_rank_gate, C))
307
+
308
+ self.k_k = nn.Parameter(torch.empty(1,1,C))
309
+ self.k_a = nn.Parameter(torch.empty(1,1,C))
310
+ self.r_k = nn.Parameter(torch.empty(H,N))
311
+
312
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
313
+ self.receptance = nn.Linear(C, C, bias=False)
314
+ self.key = nn.Linear(C, C, bias=False)
315
+ self.value = nn.Linear(C, C, bias=False)
316
+ self.output = nn.Linear(C, C, bias=False)
317
+ self.ln_x = nn.GroupNorm(H, C, eps=self.head_size * 1e-5)
318
+
319
+
320
+ def forward(self, hidden, state=None, v_first=None, use_cache=False, seq_mode=True):
321
+ # Mix hidden with the previous timestep to produce key, value, receptance
322
+ if hidden.size(1) == 1 and state is not None:
323
+ shifted = state[0][self.layer_id]
324
+ else:
325
+ shifted = self.time_shift(hidden)
326
+ if state is not None:
327
+ shifted[:, 0] = state[0][self.layer_id]
328
+ if len(shifted.size()) == 2:
329
+ shifted = shifted.unsqueeze(1)
330
+
331
+ x = hidden
332
+
333
+ B, T, C = hidden.shape
334
+ H = self.num_heads
335
+ N = self.head_size
336
+
337
+ xx = shifted - x
338
+
339
+ xr = x+xx*self.x_r
340
+ xw = x+xx*self.x_w
341
+ xk = x+xx*self.x_k
342
+ xv = x+xx*self.x_v
343
+ xa = x+xx*self.x_a
344
+ xg = x+xx*self.x_g
345
+
346
+ r = self.receptance(xr)
347
+ w = torch.tanh(xw @ self.w1) @ self.w2
348
+ k = self.key(xk)
349
+ v = self.value(xv)
350
+ a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2)
351
+ g = torch.sigmoid(xg @ self.g1) @ self.g2
352
+
353
+ kk = torch.nn.functional.normalize((k * self.k_k).view(B,T,H,-1), dim=-1, p=2.0).view(B,T,-1)
354
+ k = k * (1 + (a-1) * self.k_a)
355
+ if self.layer_id == 0: v_first = v
356
+ else: v = v + (v_first - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2)
357
+
358
+ if T == 1 or not self.training:
359
+ w = torch.exp(-0.606531 * torch.sigmoid((self.w0 + w).float())) # 0.606531 = exp(-0.5)
360
+ vk_state = state[1][self.layer_id]
361
+ for t in range(T):
362
+ r_, w_, k_, v_, kk_, a_ = r[:,t], w[:,t], k[:,t], v[:,t], kk[:,t], a[:,t]
363
+ vk = v_.view(B,H,N,1) @ k_.view(B,H,1,N)
364
+ ab = (-kk_).view(B,H,N,1) @ (kk_*a_).view(B,H,1,N)
365
+ vk_state = vk_state * w_.view(B,H,1,N) + vk_state @ ab.float() + vk.float()
366
+ xx[:,t] = (vk_state.to(dtype=x.dtype) @ r_.view(B,H,N,1)).view(B,H*N)
367
+ state[1][self.layer_id] = vk_state
368
+ # FIXME - support fast triton kernel for non-training pre-fill with state in and out
369
+ else:
370
+ w = -torch.nn.functional.softplus(-(self.w0 + w)) - 0.5
371
+ rwkv7_attn_triton(r, w, k, v, -kk, kk*a, self.head_size)
372
+
373
+ xx = torch.nn.functional.group_norm(xx.view(B*T,H*N), num_groups=H, weight=self.ln_x.weight, bias=self.ln_x.bias, eps = self.ln_x.eps).view(B,T,H*N)
374
+ xx = xx + ((r.view(B,T,H,-1)*k.view(B,T,H,-1)*self.r_k).sum(dim=-1, keepdim=True) * v.view(B,T,H,-1)).view(B,T,C)
375
+ xx = self.output(xx * g)
376
+
377
+ if state is not None:
378
+ state[0][self.layer_id] = hidden[:, -1]
379
+
380
+ return xx, state, v_first
381
+
382
+
383
+ class Rwkv7FeedForward(nn.Module):
384
+ def __init__(self, config, layer_id=0):
385
+ super().__init__()
386
+ self.config = config
387
+ self.layer_id = layer_id
388
+ hidden_size = config.hidden_size
389
+ intermediate_size = (
390
+ config.intermediate_size
391
+ if config.intermediate_size is not None
392
+ else int(config.hidden_size * 4)
393
+ )
394
+
395
+
396
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
397
+
398
+ self.x_k = nn.Parameter(torch.empty(1, 1, hidden_size))
399
+
400
+ self.key = nn.Linear(hidden_size, intermediate_size, bias=False)
401
+ self.value = nn.Linear(intermediate_size, hidden_size, bias=False)
402
+
403
+ def forward(self, hidden, state=None):
404
+ if hidden.size(1) == 1 and state is not None:
405
+ shifted = state[2][self.layer_id]
406
+ else:
407
+ shifted = self.time_shift(hidden)
408
+ if state is not None:
409
+ shifted[:, 0] = state[2][self.layer_id]
410
+ if len(shifted.size()) == 2:
411
+ shifted = shifted.unsqueeze(1)
412
+
413
+ delta_hidden_to_shifted = shifted - hidden
414
+ key = hidden + delta_hidden_to_shifted * self.x_k
415
+
416
+ key = torch.square(torch.relu(self.key(key)))
417
+ value = self.value(key)
418
+
419
+ if state is not None:
420
+ state[2][self.layer_id] = hidden[:, -1]
421
+
422
+ return value, state
423
+
424
+
425
+ class Rwkv7Block(nn.Module):
426
+ def __init__(self, config, layer_id):
427
+ super().__init__()
428
+ self.config = config
429
+ self.layer_id = layer_id
430
+
431
+ self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
432
+ self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
433
+
434
+ self.attention = Rwkv7SelfAttention(config, layer_id)
435
+ self.feed_forward = Rwkv7FeedForward(config, layer_id)
436
+
437
+ def forward(self, hidden, state=None, v_first=None, use_cache=False, output_attentions=False, seq_mode=True):
438
+ attention, state, v_first = self.attention(self.ln1(hidden), state=state, v_first=v_first, use_cache=use_cache, seq_mode=seq_mode)
439
+ hidden = hidden + attention
440
+
441
+ feed_forward, state = self.feed_forward(self.ln2(hidden), state=state)
442
+ hidden = hidden + feed_forward
443
+
444
+ outputs = (hidden, state, v_first)
445
+ if output_attentions:
446
+ outputs += (attention,)
447
+ else:
448
+ outputs += (None,)
449
+
450
+ return outputs
451
+
452
+
453
+ class Rwkv7PreTrainedModel(PreTrainedModel):
454
+ """
455
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
456
+ models.
457
+ """
458
+
459
+ config_class = Rwkv7Config
460
+ base_model_prefix = "model"
461
+ _no_split_modules = ["Rwkv7Block"]
462
+ _keep_in_fp32_modules = []
463
+ supports_gradient_checkpointing = True
464
+
465
+ def _init_weights(self, module):
466
+ return
467
+
468
+ """Initialize the weights."""
469
+ if isinstance(module, Rwkv7SelfAttention):
470
+ layer_id = module.layer_id
471
+ num_hidden_layers = module.config.num_hidden_layers
472
+ hidden_size = module.config.hidden_size
473
+ attention_hidden_size = module.attention_hidden_size
474
+ head_size = module.config.head_size
475
+ num_heads = attention_hidden_size // head_size
476
+
477
+ ratio_0_to_1 = layer_id / (num_hidden_layers - 1) # 0 to 1
478
+ ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0
479
+
480
+ time_weight = torch.tensor(
481
+ [i / hidden_size for i in range(hidden_size)],
482
+ dtype=module.x_k.dtype,
483
+ device=module.x_k.device,
484
+ )
485
+ time_weight = time_weight[None, None, :]
486
+
487
+ decay_speed = [
488
+ -7.0 + 5.0 * (n / (attention_hidden_size - 1)) ** (0.85 + 1.0 * ratio_0_to_1 ** 0.5)
489
+ for n in range(attention_hidden_size)
490
+ ]
491
+ decay_speed = torch.tensor(decay_speed, dtype=module.w0.dtype, device=module.w0.device)
492
+
493
+ with torch.no_grad():
494
+ module.x_r.copy_( 1.0 - torch.pow(time_weight, 0.2 * ratio_1_to_almost0) )
495
+ module.x_w.copy_( 1.0 - torch.pow(time_weight, 0.9 * ratio_1_to_almost0) )
496
+ module.x_k.copy_( 1.0 - (torch.pow(time_weight, 0.9 * ratio_1_to_almost0) + 0.4 * ratio_0_to_1) )
497
+ module.x_v.copy_( 1.0 - (torch.pow(time_weight, 0.4 * ratio_1_to_almost0) + 0.6 * ratio_0_to_1) )
498
+ module.x_a.copy_( 1.0 - torch.pow(time_weight, 0.9 * ratio_1_to_almost0) )
499
+ module.x_g.copy_( 1.0 - torch.pow(time_weight, 0.2 * ratio_1_to_almost0) )
500
+
501
+ def ortho_init(x, scale):
502
+ with torch.no_grad():
503
+ shape = x.shape
504
+ if len(shape) == 2:
505
+ gain = math.sqrt(shape[0] / shape[1]) if shape[0] > shape[1] else 1
506
+ nn.init.orthogonal_(x, gain=gain * scale)
507
+ elif len(shape) == 3:
508
+ gain = math.sqrt(shape[1] / shape[2]) if shape[1] > shape[2] else 1
509
+ for i in range(shape[0]):
510
+ nn.init.orthogonal_(x[i], gain=gain * scale)
511
+ else:
512
+ assert False
513
+ return x
514
+
515
+ module.w0.copy_(decay_speed.reshape(1,1,attention_hidden_size) + 0.5) # !!! 0.5 comes from F.softplus !!!
516
+ module.w1.zero_()
517
+ ortho_init(module.w2, 0.1)
518
+
519
+ module.a0.zero_()
520
+ module.a1.zero_()
521
+ ortho_init(module.a2, 0.1)
522
+
523
+ module.v0.copy_(1.0)
524
+ module.v1.zero_()
525
+ ortho_init(module.v2, 0.1)
526
+
527
+ module.g1.zero_()
528
+ ortho_init(module.g2, 0.1)
529
+
530
+ self.k_k.copy_(0.85)
531
+ self.k_a.copy_(1.0)
532
+ self.r_k.zero_()
533
+
534
+ module.receptance.weight.data.uniform_(-0.5/(hidden_size**0.5), 0.5/(attention_hidden_size**0.5))
535
+ module.key.weight.data.uniform_(-0.05/(hidden_size**0.5), 0.05/(attention_hidden_size**0.5))
536
+ module.value.weight.data.uniform_(-0.5/(hidden_size**0.5), 0.5/(attention_hidden_size**0.5))
537
+ module.output.weight.data.zero_()
538
+
539
+ elif isinstance(module, Rwkv7FeedForward):
540
+ layer_id = module.layer_id
541
+ num_hidden_layers = module.config.num_hidden_layers
542
+ hidden_size = module.config.hidden_size
543
+
544
+ ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0
545
+
546
+ time_weight = torch.tensor(
547
+ [i / hidden_size for i in range(hidden_size)],
548
+ dtype=module.x_k.dtype,
549
+ device=module.x_k.device,
550
+ )
551
+ time_weight = time_weight[None, None, :]
552
+
553
+ with torch.no_grad():
554
+ module.x_k.copy_( 1.0 - torch.pow(time_weight, ratio_1_to_almost0**4) )
555
+
556
+ self.key.weight.data.uniform_(-0.5/(hidden_size**0.5), 0.5/(hidden_size**0.5))
557
+ self.value.weight.data.zero_()
558
+
559
+ @dataclass
560
+ class Rwkv7Output(ModelOutput):
561
+ """
562
+ Class for the RWKV model outputs.
563
+ Args:
564
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
565
+ Sequence of hidden-states at the output of the last layer of the model.
566
+ state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
567
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
568
+ avoid providing the old `input_ids`.
569
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
570
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
571
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
572
+ the model at the output of each layer plus the optional initial embedding outputs.
573
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
574
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
575
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
576
+ the self-attention heads.
577
+ """
578
+
579
+ last_hidden_state: torch.FloatTensor = None
580
+ state: Optional[List[torch.FloatTensor]] = None
581
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
582
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
583
+
584
+
585
+ @dataclass
586
+ class Rwkv7CausalLMOutput(ModelOutput):
587
+ """
588
+ Base class for causal language model (or autoregressive) outputs.
589
+ Args:
590
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
591
+ Language modeling loss (for next-token prediction).
592
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
593
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
594
+ state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
595
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
596
+ avoid providing the old `input_ids`.
597
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
598
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
599
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
600
+ the model at the output of each layer plus the optional initial embedding outputs.
601
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
602
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
603
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
604
+ the self-attention heads.
605
+ """
606
+
607
+ loss: Optional[torch.FloatTensor] = None
608
+ logits: torch.FloatTensor = None
609
+ state: Optional[List[torch.FloatTensor]] = None
610
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
611
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
612
+
613
+
614
+ RWKV7_START_DOCSTRING = r"""
615
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
616
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
617
+ etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)
618
+ subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
619
+ general usage and behavior.
620
+ Parameters:
621
+ config ([`Rwkv7Config`]): Model configuration class with all the parameters of the model.
622
+ Initializing with a config file does not load the weights associated with the model, only the
623
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
624
+ """
625
+
626
+ RWKV7_INPUTS_DOCSTRING = r"""
627
+ Args:
628
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
629
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
630
+ `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
631
+ sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their
632
+ past calculated should be passed as `input_ids`. Indices can be obtained using [`AutoTokenizer`]. See
633
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
634
+ IDs?](../glossary#input-ids)
635
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
636
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
637
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
638
+ model's internal embedding lookup matrix.
639
+ state (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*):
640
+ If passed along, the model uses the previous state in all the blocks (which will give the output for the
641
+ `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
642
+ use_cache (`bool`, *optional*):
643
+ If set to `True`, the last state is returned and can be used to quickly generate the next logits.
644
+ output_attentions (`bool`, *optional*):
645
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
646
+ tensors for more detail.
647
+ output_hidden_states (`bool`, *optional*):
648
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
649
+ more detail.
650
+ return_dict (`bool`, *optional*):
651
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
652
+ """
653
+
654
+
655
+ @add_start_docstrings(
656
+ "The bare RWKV7 Model transformer outputting raw hidden-states without any specific head on top.",
657
+ RWKV7_START_DOCSTRING,
658
+ )
659
+ class Rwkv7Model(Rwkv7PreTrainedModel):
660
+ def __init__(self, config):
661
+ super().__init__(config)
662
+
663
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
664
+ self.pre_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
665
+ self.blocks = nn.ModuleList([Rwkv7Block(config, layer_id=idx) for idx in range(config.num_hidden_layers)])
666
+ self.ln_out = nn.LayerNorm(config.hidden_size)
667
+
668
+ self.gradient_checkpointing = False
669
+
670
+ # Initialize weights and apply final processing
671
+ self.post_init()
672
+
673
+ def get_input_embeddings(self):
674
+ return self.embeddings
675
+
676
+ def set_input_embeddings(self, new_embeddings):
677
+ self.embeddings = new_embeddings
678
+
679
+ @add_start_docstrings_to_model_forward(RWKV7_INPUTS_DOCSTRING)
680
+ @add_code_sample_docstrings(
681
+ checkpoint=_CHECKPOINT_FOR_DOC,
682
+ output_type=Rwkv7Output,
683
+ config_class=_CONFIG_FOR_DOC,
684
+ )
685
+ def forward(
686
+ self,
687
+ input_ids: Optional[torch.LongTensor] = None,
688
+ attention_mask: Optional[torch.LongTensor] = None, # noqa
689
+ inputs_embeds: Optional[torch.FloatTensor] = None,
690
+ state: Optional[List[torch.FloatTensor]] = None,
691
+ use_cache: Optional[bool] = None,
692
+ output_attentions: Optional[bool] = None,
693
+ output_hidden_states: Optional[bool] = None,
694
+ return_dict: Optional[bool] = None,
695
+ ) -> Union[Tuple, Rwkv7Output]:
696
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
697
+ output_hidden_states = (
698
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
699
+ )
700
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
701
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
702
+
703
+ if input_ids is not None and inputs_embeds is not None:
704
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
705
+ elif input_ids is None and inputs_embeds is None:
706
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
707
+
708
+ if inputs_embeds is None:
709
+ inputs_embeds = self.embeddings(input_ids)
710
+
711
+ if state is None:
712
+ state = []
713
+ head_size = self.config.head_size
714
+ num_heads = self.config.attention_hidden_size // head_size
715
+ state_attn_x = torch.zeros(
716
+ (self.config.num_hidden_layers, inputs_embeds.size(0), self.config.hidden_size),
717
+ dtype=inputs_embeds.dtype,
718
+ requires_grad=False,
719
+ device=inputs_embeds.device,
720
+ ).contiguous()
721
+ state_attn_vk = torch.zeros(
722
+ (
723
+ self.config.num_hidden_layers,
724
+ inputs_embeds.size(0),
725
+ num_heads,
726
+ head_size,
727
+ head_size,
728
+ ),
729
+ dtype=torch.float32,
730
+ requires_grad=False,
731
+ device=inputs_embeds.device,
732
+ ).contiguous()
733
+ state_ffn_x = torch.zeros(
734
+ (self.config.num_hidden_layers, inputs_embeds.size(0), self.config.hidden_size),
735
+ dtype=inputs_embeds.dtype,
736
+ requires_grad=False,
737
+ device=inputs_embeds.device,
738
+ ).contiguous()
739
+ state.append(state_attn_x)
740
+ state.append(state_attn_vk)
741
+ state.append(state_ffn_x)
742
+
743
+ seq_mode = inputs_embeds.shape[1] > 1
744
+ hidden_states = self.pre_ln(inputs_embeds)
745
+ v_first = None
746
+
747
+ all_self_attentions = () if output_attentions else None
748
+ all_hidden_states = () if output_hidden_states else None
749
+ for idx, block in enumerate(self.blocks):
750
+ hidden_states, state, v_first, attentions = block(
751
+ hidden_states, state=state, v_first=v_first, use_cache=use_cache, output_attentions=output_attentions, seq_mode=seq_mode
752
+ )
753
+
754
+ if output_hidden_states:
755
+ all_hidden_states = all_hidden_states + (hidden_states,)
756
+
757
+ if output_attentions:
758
+ all_self_attentions = all_self_attentions + (attentions,)
759
+
760
+ hidden_states = self.ln_out(hidden_states)
761
+
762
+ if output_hidden_states:
763
+ all_hidden_states = all_hidden_states + (hidden_states,)
764
+
765
+ if not return_dict:
766
+ return (hidden_states, state, all_hidden_states, all_self_attentions)
767
+
768
+ return Rwkv7Output(
769
+ last_hidden_state=hidden_states,
770
+ state=state,
771
+ hidden_states=all_hidden_states, # None
772
+ attentions=all_self_attentions, # None
773
+ )
774
+
775
+ # copied from HuggingFace https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py
776
+ @add_start_docstrings(
777
+ """
778
+ The RWKV7 Model transformer with a language modeling head on top (linear layer with weights tied to the input
779
+ embeddings).
780
+ """,
781
+ RWKV7_START_DOCSTRING,
782
+ )
783
+ class Rwkv7ForCausalLM(Rwkv7PreTrainedModel, GenerationMixin):
784
+ _tied_weights_keys = ["head.weight"]
785
+
786
+ def __init__(self, config):
787
+ super().__init__(config)
788
+ self.model = Rwkv7Model(config)
789
+ self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
790
+
791
+ # Initialize weights and apply final processing
792
+ self.post_init()
793
+
794
+ def get_output_embeddings(self):
795
+ return self.head
796
+
797
+ def set_output_embeddings(self, new_embeddings):
798
+ self.head = new_embeddings
799
+
800
+ def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, **kwargs):
801
+ # only last token for inputs_ids if the state is passed along.
802
+ if state is not None:
803
+ input_ids = input_ids[:, -1].unsqueeze(-1)
804
+
805
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
806
+ if inputs_embeds is not None and state is None:
807
+ model_inputs = {"inputs_embeds": inputs_embeds}
808
+ else:
809
+ model_inputs = {"input_ids": input_ids}
810
+
811
+ model_inputs["state"] = state
812
+ return model_inputs
813
+
814
+ @add_start_docstrings_to_model_forward(RWKV7_INPUTS_DOCSTRING)
815
+ @add_code_sample_docstrings(
816
+ checkpoint=_CHECKPOINT_FOR_DOC,
817
+ output_type=Rwkv7CausalLMOutput,
818
+ config_class=_CONFIG_FOR_DOC,
819
+ )
820
+ def forward(
821
+ self,
822
+ input_ids: Optional[torch.LongTensor] = None,
823
+ attention_mask: Optional[torch.LongTensor] = None,
824
+ inputs_embeds: Optional[torch.FloatTensor] = None,
825
+ state: Optional[List[torch.FloatTensor]] = None,
826
+ labels: Optional[torch.LongTensor] = None,
827
+ use_cache: Optional[bool] = None,
828
+ output_attentions: Optional[bool] = None,
829
+ output_hidden_states: Optional[bool] = None,
830
+ return_dict: Optional[bool] = None,
831
+ ) -> Union[Tuple, Rwkv7CausalLMOutput]:
832
+ r"""
833
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
834
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
835
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
836
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
837
+ """
838
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
839
+
840
+ outputs = self.model(
841
+ input_ids,
842
+ inputs_embeds=inputs_embeds,
843
+ state=state,
844
+ use_cache=use_cache,
845
+ output_attentions=output_attentions,
846
+ output_hidden_states=output_hidden_states,
847
+ return_dict=return_dict,
848
+ )
849
+ hidden_states = outputs[0]
850
+
851
+ logits = self.head(hidden_states)
852
+
853
+ loss = None
854
+ if labels is not None:
855
+ # move labels to correct device to enable model parallelism
856
+ labels = labels.to(logits.device)
857
+ # Shift so that tokens < n predict n
858
+ shift_logits = logits[..., :-1, :].contiguous()
859
+ shift_labels = labels[..., 1:].contiguous()
860
+ # Flatten the tokens
861
+ loss_fct = CrossEntropyLoss()
862
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
863
+
864
+ if not return_dict:
865
+ output = (logits,) + outputs[1:]
866
+ return ((loss,) + output) if loss is not None else output
867
+
868
+ return Rwkv7CausalLMOutput(
869
+ loss=loss,
870
+ logits=logits,
871
+ state=outputs.state,
872
+ hidden_states=outputs.hidden_states,
873
+ attentions=outputs.attentions,
874
+ )