SmerkyG commited on
Commit
683256b
·
verified ·
1 Parent(s): 68acb10

Update modeling_rwkv5.py

Browse files
Files changed (1) hide show
  1. modeling_rwkv5.py +237 -254
modeling_rwkv5.py CHANGED
@@ -1,6 +1,5 @@
1
  # coding=utf-8
2
- # Copyright 2023 Bo Peng and HuggingFace Inc. team.
3
- # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
  #
5
  # Licensed under the Apache License, Version 2.0 (the "License");
6
  # you may not use this file except in compliance with the License.
@@ -16,6 +15,7 @@
16
  """PyTorch RWKV5 World model."""
17
 
18
  from dataclasses import dataclass
 
19
  from typing import List, Optional, Tuple, Union
20
 
21
  import torch
@@ -30,6 +30,7 @@ from transformers.utils import (
30
  add_code_sample_docstrings,
31
  add_start_docstrings,
32
  add_start_docstrings_to_model_forward,
 
33
  is_ninja_available,
34
  is_torch_cuda_available,
35
  logging,
@@ -43,28 +44,23 @@ logger = logging.get_logger(__name__)
43
  _CHECKPOINT_FOR_DOC = "RWKV/rwkv-5-world-1b5"
44
  _CONFIG_FOR_DOC = "Rwkv5Config"
45
 
46
- RWKV5_PRETRAINED_MODEL_ARCHIVE_LIST = [
47
- "RWKV/rwkv-5-world-1b5",
48
- "RWKV/rwkv-5-world-3b",
49
- # See all RWKV models at https://huggingface.co/models?filter=rwkv
50
- ]
51
-
52
  rwkv5_cuda_kernel = None
53
 
54
 
 
55
  def load_wkv5_cuda_kernel(head_size):
56
  from torch.utils.cpp_extension import load as load_kernel
57
 
58
  global rwkv5_cuda_kernel
59
 
60
- kernel_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "rwkv5"
61
  cuda_kernel_files = [kernel_folder / f for f in ["wkv5_op.cpp", "wkv5_cuda.cu"]]
62
 
63
  # Only load the kernel if it's not been loaded yet or if we changed the context length
64
  if rwkv5_cuda_kernel is not None and rwkv5_cuda_kernel.head_size == head_size:
65
  return
66
 
67
- logger.info(f"Loading CUDA kernel for RWKV at head size of {head_size}.")
68
 
69
  flags = [
70
  "-res-usage",
@@ -84,200 +80,177 @@ def load_wkv5_cuda_kernel(head_size):
84
  rwkv5_cuda_kernel.head_size = head_size
85
 
86
 
87
- class WKV_5(torch.autograd.Function):
88
  @staticmethod
89
- def forward(ctx, B, T, C, H, r, k, v, w, u, s):
90
  with torch.no_grad():
91
- assert r.dtype == torch.bfloat16
92
- assert k.dtype == torch.bfloat16
93
- assert v.dtype == torch.bfloat16
94
- assert w.dtype == torch.bfloat16
95
- assert u.dtype == torch.bfloat16
96
- assert s.dtype == torch.float32
97
- ctx.B = B
98
- ctx.T = T
99
- ctx.C = C
100
- ctx.H = H
101
- assert r.is_contiguous()
102
- assert k.is_contiguous()
103
- assert v.is_contiguous()
104
- assert w.is_contiguous()
105
- assert u.is_contiguous()
106
- ew = (-torch.exp(w.float())).contiguous()
107
- eew = (torch.exp(ew)).contiguous()
108
- ctx.save_for_backward(r, k, v, eew, ew, u)
109
- y = torch.empty(
110
- (B, T, C), device=r.device, dtype=torch.bfloat16, memory_format=torch.contiguous_format
111
- ) # .uniform_(-1, 1)
112
- rwkv5_cuda_kernel.forward(B, T, C, H, r, k, v, eew, u, y, s)
113
- return y, s
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  @staticmethod
116
- def backward(ctx, gy):
117
  with torch.no_grad():
118
- assert gy.dtype == torch.bfloat16
119
- B = ctx.B
120
- T = ctx.T
121
- C = ctx.C
122
- H = ctx.H
123
- assert gy.is_contiguous()
124
- r, k, v, eew, ew, u = ctx.saved_tensors
125
- gr = torch.empty(
126
- (B, T, C),
127
- device=gy.device,
 
 
 
128
  requires_grad=False,
129
  dtype=torch.bfloat16,
130
  memory_format=torch.contiguous_format,
131
- ) # .uniform_(-1, 1)
132
- gk = torch.empty(
133
- (B, T, C),
134
- device=gy.device,
135
  requires_grad=False,
136
  dtype=torch.bfloat16,
137
  memory_format=torch.contiguous_format,
138
- ) # .uniform_(-1, 1)
139
- gv = torch.empty(
140
- (B, T, C),
141
- device=gy.device,
142
  requires_grad=False,
143
  dtype=torch.bfloat16,
144
  memory_format=torch.contiguous_format,
145
- ) # .uniform_(-1, 1)
146
- gw = torch.empty(
147
- (B, C),
148
- device=gy.device,
149
  requires_grad=False,
150
  dtype=torch.bfloat16,
151
  memory_format=torch.contiguous_format,
152
- ) # .uniform_(-1, 1)
153
- gu = torch.empty(
154
- (B, C),
155
- device=gy.device,
156
  requires_grad=False,
157
  dtype=torch.bfloat16,
158
  memory_format=torch.contiguous_format,
159
- ) # .uniform_(-1, 1)
160
- rwkv5_cuda_kernel.backward(B, T, C, H, r, k, v, eew, ew, u, gy, gr, gk, gv, gw, gu)
161
- gw = torch.sum(gw, 0).view(H, C // H)
162
- gu = torch.sum(gu, 0).view(H, C // H)
163
- return (None, None, None, None, gr, gk, gv, gw, gu)
164
-
165
-
166
- def rwkv_linear_attention_v5_cpu(
167
- B,
168
- H,
169
- S,
170
- T,
171
- n_head,
172
- hidden,
173
- time_decay,
174
- time_first,
175
- receptance,
176
- key,
177
- value,
178
- gate,
179
- lxw,
180
- lxb,
181
- ow,
182
- state,
183
- ):
184
- key = key.to(torch.float32).view(B, T, H, S).transpose(1, 2).transpose(-2, -1)
185
- value = value.to(torch.float32).view(B, T, H, S).transpose(1, 2)
186
- receptance = receptance.to(torch.float32).view(B, T, H, S).transpose(1, 2)
187
- time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1, 1, 1).reshape(n_head, -1, 1)
188
- time_first = time_first.float().reshape(-1, 1, 1).reshape(n_head, -1, 1)
189
- lxw = lxw.float()
190
- lxb = lxb.float()
191
- out = torch.zeros_like(key).reshape(B, T, H, S)
192
- for t in range(T):
193
- rt = receptance[:, :, t : t + 1, :]
194
- kt = key[:, :, :, t : t + 1]
195
- vt = value[:, :, t : t + 1, :]
196
- at = kt @ vt
197
- out[:, t] = (rt @ (time_first * at + state)).squeeze(2)
 
 
 
 
 
198
  with torch.no_grad():
199
- state = at + time_decay * state
200
-
201
- out = out.reshape(B * T, H * S)
202
- out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, T, H * S)
203
- out = out.to(dtype=hidden.dtype) * gate
204
- out = out @ ow
205
 
206
  return out, state
207
 
208
-
209
- def rwkv_linear_attention(
210
- B,
211
- H,
212
- S,
213
- T,
214
- n_head,
215
- hidden,
216
- time_decay,
217
- time_first,
218
- receptance,
219
- key,
220
- value,
221
- gate,
222
- lxw,
223
- lxb,
224
- ow,
225
- state,
226
- ):
227
- no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, receptance, key, value])
228
  # Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version
229
  # in this case).
230
  one_token = key.size(1) == 1
231
- if rwkv5_cuda_kernel is None or no_cuda or one_token:
232
- return rwkv_linear_attention_v5_cpu(
233
- B,
234
- H,
235
- S,
236
- T,
237
- n_head,
238
- hidden,
239
- time_decay,
240
- time_first,
241
- receptance,
242
- key,
243
- value,
244
- gate,
245
- lxw,
246
- lxb,
247
- ow,
248
- state,
249
  )
250
  else:
251
- out, state = WKV_5.apply(B, T, H * S, H, receptance, key, value, time_decay, time_first, state)
252
- out = out.reshape(B * T, H * S)
253
- out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, T, H * S)
254
- out = out.to(dtype=hidden.dtype) * gate
255
- out = out @ ow
256
- return out, state
257
 
258
 
259
- class RwkvSelfAttention(nn.Module):
260
  def __init__(self, config, layer_id=0):
261
  super().__init__()
262
  self.config = config
263
  kernel_loaded = rwkv5_cuda_kernel is not None and rwkv5_cuda_kernel.head_size == config.head_size
264
  if is_ninja_available() and is_torch_cuda_available() and not kernel_loaded:
265
  try:
266
- load_wkv5_cuda_kernel(config.context_length)
267
  except Exception:
268
  logger.info("Could not load the custom CUDA kernel for RWKV5 attention.")
269
  self.layer_id = layer_id
270
  hidden_size = config.hidden_size
271
- # https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py#L146
272
- num_attention_heads = hidden_size // config.head_size
273
- self.num_attention_heads = num_attention_heads
274
- attention_hidden_size = (
275
- config.attention_hidden_size if config.attention_hidden_size is not None else hidden_size
276
- )
277
  self.attention_hidden_size = attention_hidden_size
 
 
278
 
279
- self.time_decay = nn.Parameter(torch.empty(num_attention_heads, config.head_size))
280
- self.time_faaaa = nn.Parameter(torch.empty(num_attention_heads, config.head_size))
281
  self.time_mix_gate = nn.Parameter(torch.empty(1, 1, hidden_size))
282
 
283
  self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size))
@@ -290,11 +263,9 @@ class RwkvSelfAttention(nn.Module):
290
  self.receptance = nn.Linear(hidden_size, attention_hidden_size, bias=False)
291
  self.gate = nn.Linear(hidden_size, attention_hidden_size, bias=False)
292
  self.output = nn.Linear(attention_hidden_size, hidden_size, bias=False)
293
- # https://github.com/BlinkDL/RWKV-LM/blob/3db37a72356b736966ddd377268f02b80963af3f/RWKV-v4neo/src/model.py#L190C1-L190C1
294
- self.ln_x = nn.GroupNorm(hidden_size // config.head_size, hidden_size)
295
 
296
- # TODO: maybe jit, otherwise move inside forward
297
- def extract_key_value(self, B, H, S, T, hidden, state=None):
298
  # Mix hidden with the previous timestep to produce key, value, receptance
299
  if hidden.size(1) == 1 and state is not None:
300
  shifted = state[0][:, :, self.layer_id]
@@ -304,12 +275,12 @@ class RwkvSelfAttention(nn.Module):
304
  shifted[:, 0] = state[0][:, :, self.layer_id]
305
  if len(shifted.size()) == 2:
306
  shifted = shifted.unsqueeze(1)
 
307
  key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
308
  value = hidden * self.time_mix_value + shifted * (1 - self.time_mix_value)
309
  receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
310
  gate = hidden * self.time_mix_gate + shifted * (1 - self.time_mix_gate)
311
 
312
- # https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L693
313
  key = self.key(key)
314
  value = self.value(value)
315
  receptance = self.receptance(receptance)
@@ -321,45 +292,32 @@ class RwkvSelfAttention(nn.Module):
321
  return receptance, key, value, gate, state
322
 
323
  def forward(self, hidden, state=None, use_cache=False, seq_mode=True):
324
- B = hidden.shape[0]
325
- H = self.time_decay.shape[0]
326
- S = hidden.shape[-1] // H
327
- T = hidden.shape[1]
328
 
329
- receptance, key, value, gate, state = self.extract_key_value(B, H, S, T, hidden, state=state)
330
  layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
331
- rwkv, layer_state = rwkv_linear_attention(
332
- B,
333
- H,
334
- S,
335
- T,
336
- self.num_attention_heads,
337
- hidden,
338
- self.time_decay,
339
- self.time_faaaa,
340
- receptance,
341
- key,
342
- value,
343
- gate,
344
- self.ln_x.weight,
345
- self.ln_x.bias,
346
- self.output.weight.t(),
347
- state=layer_state,
348
  )
349
 
350
  if layer_state is not None:
351
  state[1][:, :, :, :, self.layer_id] = layer_state
352
 
353
- return rwkv, state
354
-
 
 
 
355
 
356
- class RwkvFeedForward(nn.Module):
 
357
  def __init__(self, config, layer_id=0):
358
  super().__init__()
359
  self.config = config
360
  self.layer_id = layer_id
361
  hidden_size = config.hidden_size
362
- # https://github.com/BlinkDL/RWKV-LM/blob/3db37a72356b736966ddd377268f02b80963af3f/RWKV-v4neo/train.py#L168
363
  intermediate_size = (
364
  config.intermediate_size
365
  if config.intermediate_size is not None
@@ -396,7 +354,8 @@ class RwkvFeedForward(nn.Module):
396
  return receptance * value, state
397
 
398
 
399
- class RwkvBlock(nn.Module):
 
400
  def __init__(self, config, layer_id):
401
  super().__init__()
402
  self.config = config
@@ -408,8 +367,8 @@ class RwkvBlock(nn.Module):
408
  self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
409
  self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
410
 
411
- self.attention = RwkvSelfAttention(config, layer_id)
412
- self.feed_forward = RwkvFeedForward(config, layer_id)
413
 
414
  def forward(self, hidden, state=None, use_cache=False, output_attentions=False, seq_mode=True):
415
  if self.layer_id == 0:
@@ -429,6 +388,7 @@ class RwkvBlock(nn.Module):
429
  return outputs
430
 
431
 
 
432
  class Rwkv5PreTrainedModel(PreTrainedModel):
433
  """
434
  An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -436,19 +396,20 @@ class Rwkv5PreTrainedModel(PreTrainedModel):
436
  """
437
 
438
  config_class = Rwkv5Config
439
- base_model_prefix = "rwkv"
440
- _no_split_modules = ["RwkvBlock"]
441
  _keep_in_fp32_modules = ["time_decay", "time_first"]
442
  supports_gradient_checkpointing = True
443
 
444
  def _init_weights(self, module):
445
  """Initialize the weights."""
446
- if isinstance(module, RwkvSelfAttention):
447
  layer_id = module.layer_id
448
  num_hidden_layers = module.config.num_hidden_layers
449
  hidden_size = module.config.hidden_size
450
  attention_hidden_size = module.attention_hidden_size
451
- num_attention_heads = hidden_size // module.config.num_attention_heads
 
452
 
453
  ratio_0_to_1 = layer_id / (num_hidden_layers - 1) # 0 to 1
454
  ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0
@@ -460,7 +421,6 @@ class Rwkv5PreTrainedModel(PreTrainedModel):
460
  )
461
  time_weight = time_weight[None, None, :]
462
 
463
- # https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py#L398
464
  decay_speed = [
465
  -6.0 + 5.0 * (h / (attention_hidden_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
466
  for h in range(attention_hidden_size)
@@ -476,15 +436,15 @@ class Rwkv5PreTrainedModel(PreTrainedModel):
476
  )
477
 
478
  with torch.no_grad():
479
- module.time_decay.data = decay_speed.reshape(num_attention_heads, module.config.num_attention_heads)
480
- module.time_faaaa.data = tmp.reshape(num_attention_heads, module.config.num_attention_heads)
481
  module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
482
 
483
  module.time_mix_value.data = torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
484
  module.time_mix_receptance.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0)
485
  module.time_mix_gate.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0)
486
 
487
- elif isinstance(module, RwkvFeedForward):
488
  layer_id = module.layer_id
489
  num_hidden_layers = module.config.num_hidden_layers
490
  hidden_size = module.config.hidden_size
@@ -503,10 +463,11 @@ class Rwkv5PreTrainedModel(PreTrainedModel):
503
  module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0)
504
 
505
 
 
506
  @dataclass
507
  class Rwkv5Output(ModelOutput):
508
  """
509
- Class for the RWKV model outputs.
510
  Args:
511
  last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
512
  Sequence of hidden-states at the output of the last layer of the model.
@@ -529,6 +490,7 @@ class Rwkv5Output(ModelOutput):
529
  attentions: Optional[Tuple[torch.FloatTensor]] = None
530
 
531
 
 
532
  @dataclass
533
  class Rwkv5CausalLMOutput(ModelOutput):
534
  """
@@ -558,7 +520,7 @@ class Rwkv5CausalLMOutput(ModelOutput):
558
  attentions: Optional[Tuple[torch.FloatTensor]] = None
559
 
560
 
561
- RWKV_START_DOCSTRING = r"""
562
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
563
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
564
  etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)
@@ -570,7 +532,7 @@ RWKV_START_DOCSTRING = r"""
570
  configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
571
  """
572
 
573
- RWKV_INPUTS_DOCSTRING = r"""
574
  Args:
575
  input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
576
  `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
@@ -600,15 +562,15 @@ RWKV_INPUTS_DOCSTRING = r"""
600
 
601
 
602
  @add_start_docstrings(
603
- "The bare RWKV Model transformer outputting raw hidden-states without any specific head on top.",
604
- RWKV_START_DOCSTRING,
605
  )
606
  class Rwkv5Model(Rwkv5PreTrainedModel):
607
  def __init__(self, config):
608
  super().__init__(config)
609
 
610
  self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
611
- self.blocks = nn.ModuleList([RwkvBlock(config, layer_id=idx) for idx in range(config.num_hidden_layers)])
612
  self.ln_out = nn.LayerNorm(config.hidden_size)
613
 
614
  self.layers_are_rescaled = False
@@ -623,7 +585,7 @@ class Rwkv5Model(Rwkv5PreTrainedModel):
623
  def set_input_embeddings(self, new_embeddings):
624
  self.embeddings = new_embeddings
625
 
626
- @add_start_docstrings_to_model_forward(RWKV_INPUTS_DOCSTRING)
627
  @add_code_sample_docstrings(
628
  checkpoint=_CHECKPOINT_FOR_DOC,
629
  output_type=Rwkv5Output,
@@ -644,6 +606,7 @@ class Rwkv5Model(Rwkv5PreTrainedModel):
644
  output_hidden_states = (
645
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
646
  )
 
647
  # rwkv5 only support inference in huggingface.
648
  use_cache = use_cache if use_cache is not None else self.config.use_cache
649
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -661,40 +624,37 @@ class Rwkv5Model(Rwkv5PreTrainedModel):
661
  if inputs_embeds is None:
662
  inputs_embeds = self.embeddings(input_ids)
663
 
664
- if use_cache and state is None:
665
- # https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L904-L906
666
  state = []
667
- num_attention_heads = self.config.hidden_size // self.config.num_attention_heads
668
- state.append(
669
- torch.zeros(
670
- (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers),
671
- dtype=inputs_embeds.dtype,
672
- requires_grad=False,
673
- device=inputs_embeds.device,
674
- ).contiguous()
675
- )
676
- state.append(
677
- torch.zeros(
678
- (
679
- inputs_embeds.size(0),
680
- num_attention_heads,
681
- self.config.hidden_size // num_attention_heads,
682
- self.config.hidden_size // num_attention_heads,
683
- self.config.num_hidden_layers,
684
- ),
685
- dtype=torch.float32,
686
- requires_grad=False,
687
- device=inputs_embeds.device,
688
- ).contiguous()
689
- )
690
- state.append(
691
- torch.zeros(
692
- (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers),
693
- dtype=inputs_embeds.dtype,
694
- requires_grad=False,
695
- device=inputs_embeds.device,
696
- ).contiguous()
697
- )
698
 
699
  seq_mode = inputs_embeds.shape[1] > 1
700
  hidden_states = inputs_embeds
@@ -757,14 +717,37 @@ class Rwkv5Model(Rwkv5PreTrainedModel):
757
 
758
  self.layers_are_rescaled = not self.training
759
 
 
 
 
 
 
 
 
 
 
 
 
 
760
 
 
 
 
 
 
 
 
 
 
 
761
  @add_start_docstrings(
762
  """
763
- The RWKV Model transformer with a language modeling head on top (linear layer with weights tied to the input
764
  embeddings).
765
  """,
766
- RWKV_START_DOCSTRING,
767
  )
 
768
  class Rwkv5ForCausalLM(Rwkv5PreTrainedModel):
769
  _tied_weights_keys = ["head.weight"]
770
 
@@ -789,7 +772,7 @@ class Rwkv5ForCausalLM(Rwkv5PreTrainedModel):
789
  else:
790
  # add in \n at the beginning
791
  input_ids = torch.cat([torch.full([1,1],11,device=input_ids.device,dtype=input_ids.dtype), input_ids])
792
-
793
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
794
  if inputs_embeds is not None and state is None:
795
  model_inputs = {"inputs_embeds": inputs_embeds}
@@ -799,7 +782,7 @@ class Rwkv5ForCausalLM(Rwkv5PreTrainedModel):
799
  model_inputs["state"] = state
800
  return model_inputs
801
 
802
- @add_start_docstrings_to_model_forward(RWKV_INPUTS_DOCSTRING)
803
  @add_code_sample_docstrings(
804
  checkpoint=_CHECKPOINT_FOR_DOC,
805
  output_type=Rwkv5CausalLMOutput,
@@ -825,7 +808,7 @@ class Rwkv5ForCausalLM(Rwkv5PreTrainedModel):
825
  """
826
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
827
 
828
- rwkv_outputs = self.rwkv(
829
  input_ids,
830
  inputs_embeds=inputs_embeds,
831
  state=state,
@@ -834,7 +817,7 @@ class Rwkv5ForCausalLM(Rwkv5PreTrainedModel):
834
  output_hidden_states=output_hidden_states,
835
  return_dict=return_dict,
836
  )
837
- hidden_states = rwkv_outputs[0]
838
 
839
  logits = self.head(hidden_states)
840
 
@@ -850,13 +833,13 @@ class Rwkv5ForCausalLM(Rwkv5PreTrainedModel):
850
  loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
851
 
852
  if not return_dict:
853
- output = (logits,) + rwkv_outputs[1:]
854
  return ((loss,) + output) if loss is not None else output
855
 
856
  return Rwkv5CausalLMOutput(
857
  loss=loss,
858
  logits=logits,
859
- state=rwkv_outputs.state,
860
- hidden_states=rwkv_outputs.hidden_states,
861
- attentions=rwkv_outputs.attentions,
862
- )
 
1
  # coding=utf-8
2
+ # Copyright 2024 The RWKV team and HuggingFace Inc. team.
 
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
 
15
  """PyTorch RWKV5 World model."""
16
 
17
  from dataclasses import dataclass
18
+ from pathlib import Path
19
  from typing import List, Optional, Tuple, Union
20
 
21
  import torch
 
30
  add_code_sample_docstrings,
31
  add_start_docstrings,
32
  add_start_docstrings_to_model_forward,
33
+ is_bitsandbytes_available,
34
  is_ninja_available,
35
  is_torch_cuda_available,
36
  logging,
 
44
  _CHECKPOINT_FOR_DOC = "RWKV/rwkv-5-world-1b5"
45
  _CONFIG_FOR_DOC = "Rwkv5Config"
46
 
 
 
 
 
 
 
47
  rwkv5_cuda_kernel = None
48
 
49
 
50
+ # Copied from https://github.com/huggingface/transformers/blob/18cbaf13dcaca7145f5652aefb9b19734c56c3cd/src/transformers/models/rwkv/modeling_rwkv.py#L65
51
  def load_wkv5_cuda_kernel(head_size):
52
  from torch.utils.cpp_extension import load as load_kernel
53
 
54
  global rwkv5_cuda_kernel
55
 
56
+ kernel_folder = Path(__file__).parent.resolve()
57
  cuda_kernel_files = [kernel_folder / f for f in ["wkv5_op.cpp", "wkv5_cuda.cu"]]
58
 
59
  # Only load the kernel if it's not been loaded yet or if we changed the context length
60
  if rwkv5_cuda_kernel is not None and rwkv5_cuda_kernel.head_size == head_size:
61
  return
62
 
63
+ logger.info(f"Loading CUDA kernel for RWKV5 at head size of {head_size}.")
64
 
65
  flags = [
66
  "-res-usage",
 
80
  rwkv5_cuda_kernel.head_size = head_size
81
 
82
 
83
+ class Rwkv5LinearAttention(torch.autograd.Function):
84
  @staticmethod
85
+ def forward(ctx, receptance, key, value, time_decay, time_first, state):
86
  with torch.no_grad():
87
+ assert receptance.dtype == torch.bfloat16
88
+ assert key.dtype == torch.bfloat16
89
+ assert value.dtype == torch.bfloat16
90
+ assert time_decay.dtype == torch.bfloat16
91
+ assert time_first.dtype == torch.bfloat16
92
+ assert state.dtype == torch.float32
93
+ batch, seq_length, hidden_size = key.shape
94
+ num_heads = time_decay.shape[0]
95
+ ctx.batch = batch
96
+ ctx.seq_length = seq_length
97
+ ctx.hidden_size = hidden_size
98
+ ctx.num_heads = num_heads
99
+ e_time_decay = (-torch.exp(time_decay.float())).contiguous()
100
+ ee_time_decay = (torch.exp(e_time_decay)).contiguous()
101
+ assert ee_time_decay.dtype == torch.float32
102
+ ctx.save_for_backward(receptance, key, value, ee_time_decay, e_time_decay, time_first)
103
+ out = torch.empty(
104
+ (batch, seq_length, hidden_size),
105
+ device=receptance.device,
106
+ dtype=torch.bfloat16,
107
+ memory_format=torch.contiguous_format,
108
+ )
109
+ state = state.clone()
110
+ rwkv5_cuda_kernel.forward_bf16(
111
+ batch,
112
+ seq_length,
113
+ hidden_size,
114
+ num_heads,
115
+ state,
116
+ receptance,
117
+ key,
118
+ value,
119
+ ee_time_decay,
120
+ time_first,
121
+ out,
122
+ )
123
+ return out, state
124
 
125
  @staticmethod
126
+ def backward(ctx, gout):
127
  with torch.no_grad():
128
+ assert gout.dtype == torch.bfloat16
129
+ batch = ctx.batch
130
+ seq_length = ctx.seq_length
131
+ hidden_size = ctx.hidden_size
132
+ num_heads = ctx.num_heads
133
+ receptance, key, value, ee_time_decay, e_time_decay, time_first = ctx.saved_tensors
134
+
135
+ global_shape = (batch, seq_length, hidden_size)
136
+
137
+ # TODO dtype should not be forced here IMO
138
+ greceptance = torch.empty(
139
+ global_shape,
140
+ device=gout.device,
141
  requires_grad=False,
142
  dtype=torch.bfloat16,
143
  memory_format=torch.contiguous_format,
144
+ )
145
+ g_key = torch.empty(
146
+ global_shape,
147
+ device=gout.device,
148
  requires_grad=False,
149
  dtype=torch.bfloat16,
150
  memory_format=torch.contiguous_format,
151
+ )
152
+ g_value = torch.empty(
153
+ global_shape,
154
+ device=gout.device,
155
  requires_grad=False,
156
  dtype=torch.bfloat16,
157
  memory_format=torch.contiguous_format,
158
+ )
159
+ g_time_decay = torch.empty(
160
+ (batch, hidden_size),
161
+ device=gout.device,
162
  requires_grad=False,
163
  dtype=torch.bfloat16,
164
  memory_format=torch.contiguous_format,
165
+ )
166
+ g_time_first = torch.empty(
167
+ (batch, hidden_size),
168
+ device=gout.device,
169
  requires_grad=False,
170
  dtype=torch.bfloat16,
171
  memory_format=torch.contiguous_format,
172
+ )
173
+ rwkv5_cuda_kernel.backward_bf16(
174
+ batch,
175
+ seq_length,
176
+ hidden_size,
177
+ num_heads,
178
+ receptance,
179
+ key,
180
+ value,
181
+ ee_time_decay,
182
+ e_time_decay,
183
+ time_first,
184
+ gout,
185
+ greceptance,
186
+ g_key,
187
+ g_value,
188
+ g_time_decay,
189
+ g_time_first,
190
+ )
191
+ head_size = hidden_size // num_heads
192
+ g_time_decay = torch.sum(g_time_decay, 0).view(num_heads, head_size)
193
+ g_time_first = torch.sum(g_time_first, 0).view(num_heads, head_size)
194
+ return (None, None, None, None, greceptance, g_key, g_value, g_time_decay, g_time_first)
195
+
196
+
197
+ def rwkv5_linear_attention_cpu(receptance, key, value, time_decay, time_first, state):
198
+ input_dtype = receptance.dtype
199
+ # For CPU fallback. Will be slower and probably take more memory than the custom CUDA kernel if not executed
200
+ # within a torch.no_grad.
201
+ batch, seq_length, hidden_size = receptance.shape
202
+ num_heads, head_size = time_first.shape
203
+ key = key.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2).transpose(-2, -1)
204
+ value = value.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2)
205
+ receptance = receptance.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2)
206
+ time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1, 1, 1).reshape(num_heads, -1, 1)
207
+ time_first = time_first.float().reshape(-1, 1, 1).reshape(num_heads, -1, 1)
208
+ out = torch.zeros_like(key).reshape(batch, seq_length, num_heads, head_size)
209
+
210
+ for current_index in range(seq_length):
211
+ current_receptance = receptance[:, :, current_index:current_index+1, :]
212
+ current_key = key[:, :, :, current_index:current_index+1]
213
+ current_value = value[:, :, current_index:current_index+1, :]
214
+ attention_output = current_key @ current_value
215
+ out[:, current_index] = (current_receptance @ (time_first * attention_output + state)).squeeze(2)
216
  with torch.no_grad():
217
+ state = attention_output + time_decay * state
 
 
 
 
 
218
 
219
  return out, state
220
 
221
+ # copied from RWKV but with receptance
222
+ def RWKV5_linear_attention(training, receptance, key, value, time_decay, time_first, state):
223
+ no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, key, value])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  # Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version
225
  # in this case).
226
  one_token = key.size(1) == 1
227
+ if not training or rwkv5_cuda_kernel is None or no_cuda or one_token:
228
+ return rwkv5_linear_attention_cpu(
229
+ receptance, key, value, time_decay, time_first, state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  )
231
  else:
232
+ return Rwkv5LinearAttention.apply(receptance, key, value, time_decay, time_first, state)
 
 
 
 
 
233
 
234
 
235
+ class Rwkv5SelfAttention(nn.Module):
236
  def __init__(self, config, layer_id=0):
237
  super().__init__()
238
  self.config = config
239
  kernel_loaded = rwkv5_cuda_kernel is not None and rwkv5_cuda_kernel.head_size == config.head_size
240
  if is_ninja_available() and is_torch_cuda_available() and not kernel_loaded:
241
  try:
242
+ load_wkv5_cuda_kernel(config.head_size)
243
  except Exception:
244
  logger.info("Could not load the custom CUDA kernel for RWKV5 attention.")
245
  self.layer_id = layer_id
246
  hidden_size = config.hidden_size
247
+ attention_hidden_size = config.attention_hidden_size
 
 
 
 
 
248
  self.attention_hidden_size = attention_hidden_size
249
+ head_size = config.head_size
250
+ num_heads = attention_hidden_size // head_size
251
 
252
+ self.time_decay = nn.Parameter(torch.empty(num_heads, head_size))
253
+ self.time_faaaa = nn.Parameter(torch.empty(num_heads, head_size))
254
  self.time_mix_gate = nn.Parameter(torch.empty(1, 1, hidden_size))
255
 
256
  self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size))
 
263
  self.receptance = nn.Linear(hidden_size, attention_hidden_size, bias=False)
264
  self.gate = nn.Linear(hidden_size, attention_hidden_size, bias=False)
265
  self.output = nn.Linear(attention_hidden_size, hidden_size, bias=False)
266
+ self.ln_x = nn.GroupNorm(num_heads, hidden_size)
 
267
 
268
+ def extract_key_value(self, hidden, state=None):
 
269
  # Mix hidden with the previous timestep to produce key, value, receptance
270
  if hidden.size(1) == 1 and state is not None:
271
  shifted = state[0][:, :, self.layer_id]
 
275
  shifted[:, 0] = state[0][:, :, self.layer_id]
276
  if len(shifted.size()) == 2:
277
  shifted = shifted.unsqueeze(1)
278
+
279
  key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
280
  value = hidden * self.time_mix_value + shifted * (1 - self.time_mix_value)
281
  receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
282
  gate = hidden * self.time_mix_gate + shifted * (1 - self.time_mix_gate)
283
 
 
284
  key = self.key(key)
285
  value = self.value(value)
286
  receptance = self.receptance(receptance)
 
292
  return receptance, key, value, gate, state
293
 
294
  def forward(self, hidden, state=None, use_cache=False, seq_mode=True):
295
+ receptance, key, value, gate, state = self.extract_key_value(hidden, state=state)
296
+
297
+ B,T,C = receptance.shape
298
+ H, S = self.time_faaaa.shape
299
 
 
300
  layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
301
+ out, layer_state = RWKV5_linear_attention(
302
+ self.training, receptance, key, value, self.time_decay, self.time_faaaa, layer_state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
  )
304
 
305
  if layer_state is not None:
306
  state[1][:, :, :, :, self.layer_id] = layer_state
307
 
308
+ out = out.reshape(B * T, H * S)
309
+ out = F.group_norm(out / self.config.head_size_divisor, num_groups=H, weight=self.ln_x.weight.to(out.dtype), bias=self.ln_x.bias.to(out.dtype), eps=self.ln_x.eps).reshape(B, T, H * S)
310
+ out = out.to(dtype=hidden.dtype) * gate
311
+ out = self.output(out)
312
+ return out, state
313
 
314
+ # Copied from rwkv exceot for the intermediate size
315
+ class Rwkv5FeedForward(nn.Module):
316
  def __init__(self, config, layer_id=0):
317
  super().__init__()
318
  self.config = config
319
  self.layer_id = layer_id
320
  hidden_size = config.hidden_size
 
321
  intermediate_size = (
322
  config.intermediate_size
323
  if config.intermediate_size is not None
 
354
  return receptance * value, state
355
 
356
 
357
+ # Copied from transformers.models.rwkv.modeling_rwkv.RwkvBlock with Rwkv->Rwkv5
358
+ class Rwkv5Block(nn.Module):
359
  def __init__(self, config, layer_id):
360
  super().__init__()
361
  self.config = config
 
367
  self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
368
  self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
369
 
370
+ self.attention = Rwkv5SelfAttention(config, layer_id)
371
+ self.feed_forward = Rwkv5FeedForward(config, layer_id)
372
 
373
  def forward(self, hidden, state=None, use_cache=False, output_attentions=False, seq_mode=True):
374
  if self.layer_id == 0:
 
388
  return outputs
389
 
390
 
391
+ # Copied from transformers.models.rwkv.modeling_rwkv.RwkvPreTrainedModel with Rwkv->Rwkv5
392
  class Rwkv5PreTrainedModel(PreTrainedModel):
393
  """
394
  An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
 
396
  """
397
 
398
  config_class = Rwkv5Config
399
+ base_model_prefix = "rwkv5"
400
+ _no_split_modules = ["Rwkv5Block"]
401
  _keep_in_fp32_modules = ["time_decay", "time_first"]
402
  supports_gradient_checkpointing = True
403
 
404
  def _init_weights(self, module):
405
  """Initialize the weights."""
406
+ if isinstance(module, Rwkv5SelfAttention):
407
  layer_id = module.layer_id
408
  num_hidden_layers = module.config.num_hidden_layers
409
  hidden_size = module.config.hidden_size
410
  attention_hidden_size = module.attention_hidden_size
411
+ head_size = module.config.head_size
412
+ num_heads = attention_hidden_size // head_size
413
 
414
  ratio_0_to_1 = layer_id / (num_hidden_layers - 1) # 0 to 1
415
  ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0
 
421
  )
422
  time_weight = time_weight[None, None, :]
423
 
 
424
  decay_speed = [
425
  -6.0 + 5.0 * (h / (attention_hidden_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
426
  for h in range(attention_hidden_size)
 
436
  )
437
 
438
  with torch.no_grad():
439
+ module.time_decay.data = decay_speed.reshape(num_heads, head_size)
440
+ module.time_faaaa.data = tmp.reshape(num_heads, head_size)
441
  module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
442
 
443
  module.time_mix_value.data = torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
444
  module.time_mix_receptance.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0)
445
  module.time_mix_gate.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0)
446
 
447
+ elif isinstance(module, Rwkv5FeedForward):
448
  layer_id = module.layer_id
449
  num_hidden_layers = module.config.num_hidden_layers
450
  hidden_size = module.config.hidden_size
 
463
  module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0)
464
 
465
 
466
+ # Copied from transformers.models.rwkv.modeling_rwkv.RwkvOutput with Rwkv->Rwkv5
467
  @dataclass
468
  class Rwkv5Output(ModelOutput):
469
  """
470
+ Class for the RWKV5 model outputs.
471
  Args:
472
  last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
473
  Sequence of hidden-states at the output of the last layer of the model.
 
490
  attentions: Optional[Tuple[torch.FloatTensor]] = None
491
 
492
 
493
+ # Copied from transformers.models.rwkv.modeling_rwkv.RwkvCausalLMOutput with Rwkv->Rwkv5
494
  @dataclass
495
  class Rwkv5CausalLMOutput(ModelOutput):
496
  """
 
520
  attentions: Optional[Tuple[torch.FloatTensor]] = None
521
 
522
 
523
+ RWKV5_START_DOCSTRING = r"""
524
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
525
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
526
  etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)
 
532
  configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
533
  """
534
 
535
+ RWKV5_INPUTS_DOCSTRING = r"""
536
  Args:
537
  input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
538
  `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
 
562
 
563
 
564
  @add_start_docstrings(
565
+ "The bare RWKV5 Model transformer outputting raw hidden-states without any specific head on top.",
566
+ RWKV5_START_DOCSTRING,
567
  )
568
  class Rwkv5Model(Rwkv5PreTrainedModel):
569
  def __init__(self, config):
570
  super().__init__(config)
571
 
572
  self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
573
+ self.blocks = nn.ModuleList([Rwkv5Block(config, layer_id=idx) for idx in range(config.num_hidden_layers)])
574
  self.ln_out = nn.LayerNorm(config.hidden_size)
575
 
576
  self.layers_are_rescaled = False
 
585
  def set_input_embeddings(self, new_embeddings):
586
  self.embeddings = new_embeddings
587
 
588
+ @add_start_docstrings_to_model_forward(RWKV5_INPUTS_DOCSTRING)
589
  @add_code_sample_docstrings(
590
  checkpoint=_CHECKPOINT_FOR_DOC,
591
  output_type=Rwkv5Output,
 
606
  output_hidden_states = (
607
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
608
  )
609
+ # FIXME - training is supportable with the CUDA code
610
  # rwkv5 only support inference in huggingface.
611
  use_cache = use_cache if use_cache is not None else self.config.use_cache
612
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
624
  if inputs_embeds is None:
625
  inputs_embeds = self.embeddings(input_ids)
626
 
627
+ if state is None:
 
628
  state = []
629
+ head_size = self.config.head_size
630
+ num_heads = self.config.attention_hidden_size // head_size
631
+ state_attn_x = torch.zeros(
632
+ (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers),
633
+ dtype=inputs_embeds.dtype,
634
+ requires_grad=False,
635
+ device=inputs_embeds.device,
636
+ ).contiguous()
637
+ state_attn_kv = torch.zeros(
638
+ (
639
+ inputs_embeds.size(0),
640
+ num_heads,
641
+ head_size,
642
+ head_size,
643
+ self.config.num_hidden_layers,
644
+ ),
645
+ dtype=torch.float32,
646
+ requires_grad=False,
647
+ device=inputs_embeds.device,
648
+ ).contiguous()
649
+ state_ffn_x = torch.zeros(
650
+ (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers),
651
+ dtype=inputs_embeds.dtype,
652
+ requires_grad=False,
653
+ device=inputs_embeds.device,
654
+ ).contiguous()
655
+ state.append(state_attn_x)
656
+ state.append(state_attn_kv)
657
+ state.append(state_ffn_x)
 
 
658
 
659
  seq_mode = inputs_embeds.shape[1] > 1
660
  hidden_states = inputs_embeds
 
717
 
718
  self.layers_are_rescaled = not self.training
719
 
720
+ def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id):
721
+ r"""
722
+ Perform the dequantization and rescaling of the weights of a given layer. After that operation the layer will
723
+ be quantized again.
724
+ """
725
+ if not is_bitsandbytes_available():
726
+ raise ImportError("Please install bitsandbytes to use this method.")
727
+ import bitsandbytes as bnb
728
+
729
+ dequant_weights = bnb.functional.dequantize_4bit(target_layer.weight.data, target_layer.weight.quant_state)
730
+
731
+ dequant_weights.div_(2 ** int(block_id // self.config.rescale_every))
732
 
733
+ # re-quantize the model:
734
+ # we need to put it first on CPU then back to the device
735
+ # this will create an overhead :/
736
+ # We set requires_grad=False as we cannot compute gradients on top of 4bit parameters anyway and to avoid
737
+ # bugs with bnb
738
+ quant_weight = bnb.nn.Params4bit(dequant_weights.to("cpu"), requires_grad=False).to(dequant_weights.device)
739
+ setattr(target_layer, "weight", quant_weight)
740
+
741
+
742
+ # copied from HuggingFace https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py
743
  @add_start_docstrings(
744
  """
745
+ The RWKV5 Model transformer with a language modeling head on top (linear layer with weights tied to the input
746
  embeddings).
747
  """,
748
+ RWKV5_START_DOCSTRING,
749
  )
750
+ # Copied from transformers.models.rwkv.modeling_rwkv.RwkvForCausalLM with Rwkv->Rwkv5
751
  class Rwkv5ForCausalLM(Rwkv5PreTrainedModel):
752
  _tied_weights_keys = ["head.weight"]
753
 
 
772
  else:
773
  # add in \n at the beginning
774
  input_ids = torch.cat([torch.full([1,1],11,device=input_ids.device,dtype=input_ids.dtype), input_ids])
775
+
776
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
777
  if inputs_embeds is not None and state is None:
778
  model_inputs = {"inputs_embeds": inputs_embeds}
 
782
  model_inputs["state"] = state
783
  return model_inputs
784
 
785
+ @add_start_docstrings_to_model_forward(RWKV5_INPUTS_DOCSTRING)
786
  @add_code_sample_docstrings(
787
  checkpoint=_CHECKPOINT_FOR_DOC,
788
  output_type=Rwkv5CausalLMOutput,
 
808
  """
809
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
810
 
811
+ outputs = self.rwkv(
812
  input_ids,
813
  inputs_embeds=inputs_embeds,
814
  state=state,
 
817
  output_hidden_states=output_hidden_states,
818
  return_dict=return_dict,
819
  )
820
+ hidden_states = outputs[0]
821
 
822
  logits = self.head(hidden_states)
823
 
 
833
  loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
834
 
835
  if not return_dict:
836
+ output = (logits,) + outputs[1:]
837
  return ((loss,) + output) if loss is not None else output
838
 
839
  return Rwkv5CausalLMOutput(
840
  loss=loss,
841
  logits=logits,
842
+ state=outputs.state,
843
+ hidden_states=outputs.hidden_states,
844
+ attentions=outputs.attentions,
845
+ )