SmerkyG commited on
Commit
c833b01
·
verified ·
1 Parent(s): 272ac41

Upload folder using huggingface_hub

Browse files
added_tokens.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "<s>": 0
3
+ }
config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Rwkv7ForCausalLM"
4
+ ],
5
+ "attention_hidden_size": 768,
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_rwkv7.Rwkv7Config",
8
+ "AutoModelForCausalLM": "modeling_rwkv7.Rwkv7ForCausalLM"
9
+ },
10
+ "bos_token_id": 0,
11
+ "eos_token_id": 0,
12
+ "head_size": 64,
13
+ "hidden_size": 768,
14
+ "intermediate_size": null,
15
+ "layer_norm_epsilon": 1e-05,
16
+ "lora_rank_decay": null,
17
+ "lora_rank_gate": null,
18
+ "lora_rank_iclr": null,
19
+ "lora_rank_value_residual_mix": null,
20
+ "model_type": "rwkv7",
21
+ "num_hidden_layers": 12,
22
+ "tie_word_embeddings": false,
23
+ "transformers_version": "4.46.2",
24
+ "use_cache": true,
25
+ "vocab_size": 65536
26
+ }
configuration_rwkv7.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The OpenAI Team Authors and HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ RWKV configuration"""
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import logging
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+ RWKV7_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
25
+
26
+
27
+ class Rwkv7Config(PretrainedConfig):
28
+ """
29
+ This is the configuration class to store the configuration of a [`Rwkv7Model`]. It is used to instantiate a RWKV7
30
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
31
+ defaults will yield a similar configuration to that of the RWVK-7
32
+ [RWKV/v7-Goose-1.6B-Pile-HF](https://huggingface.co/RWKV/v7-Goose-1.6B-Pile-HF) architecture.
33
+
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+
37
+
38
+ Args:
39
+ vocab_size (`int`, *optional*, defaults to 65536):
40
+ Vocabulary size of the RWKV7 model. Defines the number of different tokens that can be represented by the
41
+ `inputs_ids` passed when calling [`Rwkv7Model`].
42
+ hidden_size (`int`, *optional*, defaults to 768):
43
+ Dimensionality of the embeddings and hidden states.
44
+ num_hidden_layers (`int`, *optional*, defaults to 24):
45
+ Number of hidden layers in the model.
46
+ attention_hidden_size (`int`, *optional*):
47
+ Dimensionality of the attention hidden states. Will default to `hidden_size` if unset.
48
+ num_attention_heads (`int`, *optional*, defaults to 64):
49
+ The attention heads to use in rwkv7 self_attention module.
50
+ head_size (`int`, *optional*, defaults to 64): head_size of rwkv7 self_attention module.
51
+ intermediate_size (`int`, *optional*):
52
+ Dimensionality of the inner feed-forward layers. Will default to 4 times `hidden_size` if unset.
53
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
54
+ The epsilon to use in the layer normalization layers.
55
+ bos_token_id (`int`, *optional*, defaults to 0):
56
+ The id of the beginning of sentence token in the vocabulary. Defaults to 0.
57
+ eos_token_id (`int`, *optional*, defaults to 0):
58
+ The id of the end of sentence token in the vocabulary. Defaults to 0.
59
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
60
+ Whether or not to tie the word embeddings with the input token embeddings.
61
+ use_cache (`bool`, *optional*, defaults to `True`):
62
+ Whether or not the model should return the last state.
63
+
64
+
65
+ Example:
66
+
67
+ ```python
68
+ >>> from transformers import Rwkv7Config, Rwkv7Model
69
+
70
+ >>> # Initializing a Rwkv7 configuration
71
+ >>> configuration = Rwkv7Config()
72
+
73
+ >>> # Initializing a model (with random weights) from the configuration
74
+ >>> model = Rwkv7Model(configuration)
75
+
76
+ >>> # Accessing the model configuration
77
+ >>> configuration = model.config
78
+ ```"""
79
+
80
+ model_type = "rwkv7"
81
+
82
+ def __init__(
83
+ self,
84
+ vocab_size=65536,
85
+ hidden_size=768,
86
+ num_hidden_layers=24,
87
+ attention_hidden_size=None,
88
+ head_size=64,
89
+ intermediate_size=None,
90
+ lora_rank_decay=None,
91
+ lora_rank_iclr=None,
92
+ lora_rank_value_residual_mix=None,
93
+ lora_rank_gate=None,
94
+ layer_norm_epsilon=1e-5,
95
+ bos_token_id=0,
96
+ eos_token_id=0,
97
+ tie_word_embeddings=False,
98
+ use_cache=True,
99
+ **kwargs,
100
+ ):
101
+ self.vocab_size = vocab_size
102
+ self.hidden_size = hidden_size
103
+ self.num_hidden_layers = num_hidden_layers
104
+ self.attention_hidden_size = attention_hidden_size if attention_hidden_size is not None else hidden_size
105
+ self.head_size = head_size
106
+ self.intermediate_size = intermediate_size
107
+ self.lora_rank_decay = lora_rank_decay
108
+ self.lora_rank_iclr = lora_rank_iclr
109
+ self.lora_rank_value_residual_mix = lora_rank_value_residual_mix
110
+ self.lora_rank_gate = lora_rank_gate
111
+ self.layer_norm_epsilon = layer_norm_epsilon
112
+ self.use_cache = use_cache
113
+
114
+ super().__init__(
115
+ tie_word_embeddings=tie_word_embeddings, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs
116
+ )
generation_config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "chat_format": "chatml",
3
+ "eos_token_id": 0,
4
+ "pad_token_id": 0,
5
+ "max_window_size": 4096,
6
+ "max_new_tokens": 4096,
7
+ "do_sample": true,
8
+ "top_k": 0,
9
+ "top_p": 0.1,
10
+ "repetition_penalty": 1.0,
11
+ "transformers_version": "4.31.1"
12
+ }
hf_rwkv_tokenizer.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for RWKV6."""
16
+
17
+ import os
18
+ import re
19
+ from typing import TYPE_CHECKING, List, Optional, Tuple
20
+
21
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
22
+ from transformers.utils import logging
23
+
24
+
25
+ if TYPE_CHECKING:
26
+ pass
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ VOCAB_FILES_NAMES = {
32
+ "vocab_file": "rwkv_vocab_v20230424.txt",
33
+ }
34
+
35
+ class TRIE:
36
+ __slots__ = tuple("ch,to,values,front".split(","))
37
+ to: list
38
+ values: set
39
+
40
+ def __init__(self, front=None, ch=None):
41
+ self.ch = ch
42
+ self.to = [None for ch in range(256)]
43
+ self.values = set()
44
+ self.front = front
45
+
46
+ def __repr__(self):
47
+ fr = self
48
+ ret = []
49
+ while fr != None:
50
+ if fr.ch != None:
51
+ ret.append(fr.ch)
52
+ fr = fr.front
53
+ return "<TRIE %s %s>" % (ret[::-1], self.values)
54
+
55
+ def add(self, key: bytes, idx: int = 0, val=None):
56
+ if idx == len(key):
57
+ if val is None:
58
+ val = key
59
+ self.values.add(val)
60
+ return self
61
+ ch = key[idx]
62
+ if self.to[ch] is None:
63
+ self.to[ch] = TRIE(front=self, ch=ch)
64
+ return self.to[ch].add(key, idx=idx + 1, val=val)
65
+
66
+ def find_longest(self, key: bytes, idx: int = 0):
67
+ u: TRIE = self
68
+ ch: int = key[idx]
69
+
70
+ while u.to[ch] is not None:
71
+ u = u.to[ch]
72
+ idx += 1
73
+ if u.values:
74
+ ret = idx, u, u.values
75
+ if idx == len(key):
76
+ break
77
+ ch = key[idx]
78
+ return ret
79
+
80
+
81
+ class RWKV_TOKENIZER:
82
+ def __init__(self, file_name):
83
+ self.idx2token = {}
84
+ sorted = [] # must be already sorted
85
+ with open(file_name, "r", encoding="utf-8") as f:
86
+ lines = f.readlines()
87
+ for l in lines:
88
+ idx = int(l[: l.index(" ")])
89
+ x = eval(l[l.index(" ") : l.rindex(" ")])
90
+ x = x.encode("utf-8") if isinstance(x, str) else x
91
+ assert isinstance(x, bytes)
92
+
93
+ assert len(x) == int(l[l.rindex(" ") :])
94
+ sorted += [x]
95
+ self.idx2token[idx] = x
96
+
97
+ self.token2idx = {}
98
+ for k, v in self.idx2token.items():
99
+ self.token2idx[v] = int(k)
100
+
101
+ self.root = TRIE()
102
+ for t, i in self.token2idx.items():
103
+ _ = self.root.add(t, val=(t, i))
104
+
105
+ def encodeBytes(self, src: bytes):
106
+ idx: int = 0
107
+ tokens = []
108
+ while idx < len(src):
109
+ _idx: int = idx
110
+ idx, _, values = self.root.find_longest(src, idx)
111
+ assert idx != _idx
112
+ _, token = next(iter(values))
113
+ tokens.append(token)
114
+ return tokens
115
+
116
+ def decodeBytes(self, tokens):
117
+ return b"".join(map(lambda i: self.idx2token[i], tokens))
118
+
119
+ def encode(self, src):
120
+ if isinstance(src, str):
121
+ return [self.encodeBytes(src.encode("utf-8"))]
122
+ elif isinstance(src, list):
123
+ return [self.encodeBytes(s.encode("utf-8")) for s in src]
124
+
125
+ def decode(self, tokens):
126
+ return [self.decodeBytes(batch).decode("utf-8") for batch in tokens]
127
+ # try:
128
+ # return self.decodeBytes(tokens).decode('utf-8')
129
+ # except:
130
+ # return '\ufffd' # bad utf-8
131
+
132
+ def printTokens(self, tokens):
133
+ for i in tokens:
134
+ s = self.idx2token[i]
135
+ try:
136
+ s = s.decode("utf-8")
137
+ except:
138
+ pass
139
+ print(f"{repr(s)}{i}", end=" ")
140
+ print()
141
+
142
+
143
+ class Rwkv6Tokenizer(PreTrainedTokenizer):
144
+ vocab_files_names = VOCAB_FILES_NAMES
145
+ model_input_names = ["input_ids", "attention_mask"]
146
+
147
+ def __init__(
148
+ self, vocab_file, bos_token="<s>", eos_token="<s>", unk_token="<s>", **kwargs
149
+ ):
150
+ if not os.path.isfile(vocab_file):
151
+ raise ValueError(
152
+ f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
153
+ " model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
154
+ )
155
+
156
+ with open(vocab_file, "r", encoding="utf-8") as reader:
157
+ tokens = reader.readlines()
158
+
159
+ if "add_bos_token" in kwargs:
160
+ self.add_bos_token = kwargs["add_bos_token"]
161
+ else:
162
+ self.add_bos_token = False
163
+ self.trie_tokenizer = RWKV_TOKENIZER(vocab_file)
164
+ vocab = self.trie_tokenizer.token2idx
165
+ self.encoder = vocab
166
+ self.decoder = {v: k for k, v in vocab.items()}
167
+ self._added_tokens_decoder = {0: AddedToken(str(bos_token))}
168
+ super().__init__(
169
+ bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs
170
+ )
171
+
172
+ @property
173
+ def vocab_size(self):
174
+ return len(self.encoder)
175
+
176
+ def get_vocab(self):
177
+ vocab = {str(self.convert_ids_to_tokens(i)): i for i in range(self.vocab_size)}
178
+ vocab.update(self.added_tokens_encoder)
179
+ return vocab
180
+
181
+ def _tokenize(self, text, split_special_tokens=False):
182
+ # return self.wordpiece_tokenizer.tokenize(text.encode("utf-8"))
183
+ return self.trie_tokenizer.encode(text)[0]
184
+
185
+ def _convert_token_to_id(self, token):
186
+ return token
187
+
188
+ def _convert_id_to_token(self, index):
189
+ """Converts an index (integer) in a token (byte) using the vocab."""
190
+ token = self.decoder.get(index, self.unk_token)
191
+ if isinstance(token, (bytes)):
192
+ token = token.decode("utf-8", errors="replace")
193
+ return token
194
+
195
+ def convert_tokens_to_string(self, tokens):
196
+ """Converts a sequence of tokens (bytes) in a single string. Additional tokens are encoded to bytes"""
197
+ out_string = b"".join(
198
+ [k.encode(errors="replace") if isinstance(k, str) else k for k in tokens]
199
+ ).decode("utf-8")
200
+ return out_string
201
+
202
+ def save_vocabulary(
203
+ self, save_directory: str, filename_prefix: Optional[str] = None
204
+ ) -> Tuple[str]:
205
+ index = 0
206
+ if os.path.isdir(save_directory):
207
+ vocab_file = os.path.join(
208
+ save_directory,
209
+ (filename_prefix + "-" if filename_prefix else "") + "vocab.txt",
210
+ )
211
+ else:
212
+ vocab_file = (
213
+ filename_prefix + "-" if filename_prefix else ""
214
+ ) + save_directory
215
+ with open(vocab_file, "w", encoding="utf-8") as writer:
216
+ for token, token_index in sorted(
217
+ self.encoder.items(), key=lambda kv: kv[1]
218
+ ):
219
+ if index != token_index:
220
+ logger.warning(
221
+ f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
222
+ " Please check that the vocabulary is not corrupted!"
223
+ )
224
+ index = token_index
225
+ writer.write(str(token) + "\n")
226
+ index += 1
227
+ return (vocab_file,)
228
+
229
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
230
+ if self.add_bos_token:
231
+ bos_token_ids = [self.bos_token_id]
232
+ else:
233
+ bos_token_ids = []
234
+
235
+ output = bos_token_ids + token_ids_0
236
+
237
+ if token_ids_1 is None:
238
+ return output
239
+
240
+ return output + bos_token_ids + token_ids_1
241
+
242
+ def get_special_tokens_mask(
243
+ self,
244
+ token_ids_0: List[int],
245
+ token_ids_1: Optional[List[int]] = None,
246
+ already_has_special_tokens: bool = False,
247
+ ) -> List[int]:
248
+ """
249
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
250
+ special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
251
+
252
+ Args:
253
+ token_ids_0 (`List[int]`):
254
+ List of IDs.
255
+ token_ids_1 (`List[int]`, *optional*):
256
+ Optional second list of IDs for sequence pairs.
257
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
258
+ Whether or not the token list is already formatted with special tokens for the model.
259
+
260
+ Returns:
261
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
262
+ """
263
+ if already_has_special_tokens:
264
+ return super().get_special_tokens_mask(
265
+ token_ids_0=token_ids_0,
266
+ token_ids_1=token_ids_1,
267
+ already_has_special_tokens=True,
268
+ )
269
+
270
+ if not self.add_bos_token:
271
+ return super().get_special_tokens_mask(
272
+ token_ids_0=token_ids_0,
273
+ token_ids_1=token_ids_1,
274
+ already_has_special_tokens=False,
275
+ )
276
+
277
+ if token_ids_1 is None:
278
+ return [1] + ([0] * len(token_ids_0))
279
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1e926b3212efaa0fd5a6544129f6a6edb7c614539cf3486c54c08020b45825f1
3
+ size 382110640
modeling_rwkv7.py ADDED
@@ -0,0 +1,875 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The RWKV team and HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch RWKV7 World model."""
16
+
17
+ from dataclasses import dataclass
18
+ from typing import List, Optional, Tuple, Union
19
+
20
+ from pathlib import Path
21
+
22
+ import math
23
+ import torch
24
+ import torch.nn.functional as F
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+ from torch.nn import CrossEntropyLoss
28
+
29
+ from transformers.modeling_utils import PreTrainedModel, GenerationMixin, _init_weights
30
+ from transformers.utils import (
31
+ ModelOutput,
32
+ add_code_sample_docstrings,
33
+ add_start_docstrings,
34
+ add_start_docstrings_to_model_forward,
35
+ is_ninja_available,
36
+ is_torch_cuda_available,
37
+ logging,
38
+ )
39
+
40
+ from .configuration_rwkv7 import Rwkv7Config
41
+
42
+ # MIT License
43
+
44
+ # Copyright (c) 2024 Songlin Yang
45
+
46
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
47
+ # of this software and associated documentation files (the "Software"), to deal
48
+ # in the Software without restriction, including without limitation the rights
49
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
50
+ # copies of the Software, and to permit persons to whom the Software is
51
+ # furnished to do so, subject to the following conditions:
52
+
53
+ # The above copyright notice and this permission notice shall be included in all
54
+ # copies or substantial portions of the Software.
55
+
56
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
57
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
58
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
59
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
60
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
61
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
62
+ # SOFTWARE.
63
+
64
+ # Copyright (c) 2024, Johan Sokrates Wind
65
+
66
+ import torch as th
67
+ import triton
68
+ import triton.language as tl
69
+
70
+ @triton.jit
71
+ def IND4(a,b,c,d,nb,nc,nd):
72
+ return ((a*nb+b)*nc+c)*nd+d
73
+ @triton.jit
74
+ def IND5(a,b,c,d,e,nb,nc,nd,ne):
75
+ return (((a*nb+b)*nc+c)*nd+d)*ne+e
76
+
77
+ @triton.jit
78
+ def _prod(a,b): return a*b
79
+
80
+ # inv(I-A) where A is a strictly lower triangular nxn matrix
81
+ @triton.jit
82
+ def tri_minv(A, n:tl.constexpr, prec:tl.constexpr):
83
+ i = tl.arange(0,n)
84
+ prod = (i[None,:]==i[:,None]).to(tl.float32)
85
+ for j in range(n-1):
86
+ prod += tl_dot(prec, prod, (A*((i[None,:]==j)*(i[:,None]>i[None,:]))).trans())
87
+ return prod.trans()
88
+
89
+ @triton.jit
90
+ def fw_attn_triton(w_,q_,k_,v_,a_,b_, s0_,y_,s_,sT_, B:tl.constexpr,T:tl.constexpr,H:tl.constexpr,C:tl.constexpr,dT:tl.constexpr, prec:tl.constexpr):
91
+ bi = tl.program_id(1)
92
+ hi = tl.program_id(0)
93
+
94
+ i = tl.arange(0,C)[None,:]
95
+ state = tl.load(s0_+IND4(bi,hi,i.trans(),i, H,C,C)).to(tl.float32)
96
+ for t0 in range(T//dT):
97
+ t = t0*dT+tl.arange(0,dT)[:,None]
98
+ sw = tl.load(w_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
99
+ sq = tl.load(q_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
100
+ sk = tl.load(k_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
101
+ sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
102
+ sa = tl.load(a_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
103
+ sb = tl.load(b_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
104
+
105
+ w = (-sw.exp()).exp()
106
+ fw = tl.reduce(w, 0, _prod, keep_dims=True)
107
+ incl_pref = tl.cumprod(w,axis=0)
108
+ non_incl_pref = incl_pref / w
109
+ inv_incl_pref = 1 / incl_pref
110
+
111
+ wq = sq * incl_pref
112
+ wa = sa * non_incl_pref
113
+ kwi = sk * inv_incl_pref
114
+ bwi = sb * inv_incl_pref
115
+
116
+ mask1 = (t > t.trans())
117
+ ab = tl_dot(prec, wa, bwi.trans()) * mask1
118
+ ak = tl_dot(prec, wa, kwi.trans()) * mask1
119
+
120
+ ab_inv = tri_minv(ab, dT, prec)
121
+
122
+ ab_u = tl_dot(prec, ak, sv) + tl_dot(prec, wa, state.trans())
123
+ u = tl_dot(prec, ab_inv, ab_u)
124
+ mask2 = (t >= t.trans())
125
+ qk = tl_dot(prec, wq, kwi.trans()) * mask2
126
+ qb = tl_dot(prec, wq, bwi.trans()) * mask2
127
+ yy = tl_dot(prec, qk, sv) + tl_dot(prec, qb, u) + tl_dot(prec, wq, state.trans())
128
+ tl.store(y_+IND4(bi,t,hi,i, T,H,C), yy.to(tl.bfloat16))
129
+
130
+ tl.store(s_+IND5(bi,hi,t0,i.trans(),i, H,T//dT,C,C), state.to(tl.float32))
131
+ state = state * fw + tl_dot(prec, sv.trans(), kwi*fw) + tl_dot(prec, u.trans(), bwi*fw)
132
+ tl.store(sT_+IND4(bi,hi,i.trans(),i, H,C,C), state.to(tl.bfloat16))
133
+
134
+ @triton.jit
135
+ def bw_attn_triton(w_,q_,k_,v_,a_,b_, dy_,s_,dsT_, dw_,dq_,dk_,dv_,da_,db_,ds0_, B:tl.constexpr,T:tl.constexpr,H:tl.constexpr,C:tl.constexpr,dT:tl.constexpr, prec:tl.constexpr):
136
+ bi = tl.program_id(1)
137
+ hi = tl.program_id(0)
138
+
139
+ i = tl.arange(0,C)[None,:]
140
+ dstate = tl.load(dsT_+IND4(bi,hi,i.trans(),i, H,C,C)).to(tl.float32)
141
+
142
+ for t0 in range(T//dT-1,-1,-1):
143
+ t = t0*dT+tl.arange(0,dT)[:,None]
144
+
145
+ state = tl.load(s_+IND5(bi,hi,t0,i.trans(),i, H,T//dT,C,C)).to(tl.float32)
146
+
147
+ sw = tl.load(w_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
148
+ sq = tl.load(q_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
149
+ sk = tl.load(k_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
150
+ sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
151
+ sa = tl.load(a_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
152
+ sb = tl.load(b_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
153
+ sdy = tl.load(dy_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
154
+
155
+ dw_fac = -sw.exp()
156
+ w = dw_fac.exp()
157
+ fw = tl.reduce(w, 0, _prod, keep_dims=True)
158
+ incl_pref = tl.cumprod(w,axis=0)
159
+ non_incl_pref = incl_pref / w
160
+ inv_incl_pref = 1 / incl_pref
161
+
162
+ wq = sq * incl_pref
163
+ wa = sa * non_incl_pref
164
+ kwi = sk * inv_incl_pref
165
+ bwi = sb * inv_incl_pref
166
+
167
+ mask1 = (t > t.trans())
168
+ ab = tl_dot(prec, wa, bwi.trans()) * mask1
169
+ ak = tl_dot(prec, wa, kwi.trans()) * mask1
170
+
171
+ ab_inv = tri_minv(ab, dT, prec)
172
+
173
+ ab_u = tl_dot(prec, ak, sv) + tl_dot(prec, wa, state.trans())
174
+ u = tl_dot(prec, ab_inv, ab_u)
175
+ mask2 = (t >= t.trans())
176
+ qk = tl_dot(prec, wq, kwi.trans()) * mask2
177
+ qb = tl_dot(prec, wq, bwi.trans()) * mask2
178
+
179
+ du = tl_dot(prec, qb.trans(), sdy) + tl_dot(prec, bwi*fw, dstate.trans())
180
+ dab_u = tl_dot(prec, ab_inv.trans(), du)
181
+
182
+ dv = tl_dot(prec, qk.trans(), sdy) + tl_dot(prec, kwi*fw, dstate.trans()) + tl_dot(prec, ak.trans(), dab_u)
183
+ tl.store(dv_+IND4(bi,t,hi,i, T,H,C), dv.to(tl.bfloat16))
184
+
185
+ dab = tl_dot(prec, tl_dot(prec, ab_inv.trans(), du), u.trans()) * mask1
186
+ dak = tl_dot(prec, dab_u, sv.trans()) * mask1
187
+ dab_u_state = tl_dot(prec, dab_u, state)
188
+ da = non_incl_pref * (tl_dot(prec, dab, bwi) + tl_dot(prec, dak, kwi) + dab_u_state)
189
+ tl.store(da_+IND4(bi,t,hi,i, T,H,C), da.to(tl.bfloat16))
190
+
191
+ dqb = tl_dot(prec, sdy, u.trans()) * mask2
192
+ dqk = tl_dot(prec, sdy, sv.trans()) * mask2
193
+ dy_state = tl_dot(prec, sdy, state)
194
+ dq = incl_pref * (tl_dot(prec, dqb, bwi) + tl_dot(prec, dqk, kwi) + dy_state)
195
+ tl.store(dq_+IND4(bi,t,hi,i, T,H,C), dq.to(tl.bfloat16))
196
+
197
+ fw_u_dstate = fw * tl_dot(prec, u, dstate)
198
+ db = inv_incl_pref * (tl_dot(prec, dab.trans(), wa) + tl_dot(prec, dqb.trans(), wq) + fw_u_dstate)
199
+ tl.store(db_+IND4(bi,t,hi,i, T,H,C), db.to(tl.bfloat16))
200
+
201
+ fw_v_dstate = fw * tl_dot(prec, sv, dstate)
202
+ dk = inv_incl_pref * (tl_dot(prec, dak.trans(), wa) + tl_dot(prec, dqk.trans(), wq) + fw_v_dstate)
203
+ tl.store(dk_+IND4(bi,t,hi,i, T,H,C), dk.to(tl.bfloat16))
204
+
205
+ dw0 = fw * tl.sum(state*dstate, axis=0,keep_dims=True)
206
+ for k in range(t0*dT,t0*dT+dT):
207
+ lmask = (t<k).trans()
208
+ A = (tl_dot(prec, dab*lmask, bwi) + tl_dot(prec, dak*lmask, kwi)) * wa * (t>k)
209
+ A += (tl_dot(prec, dqb*lmask, bwi) + tl_dot(prec, dqk*lmask, kwi)) * wq * (t>=k)
210
+ A += (fw_v_dstate*kwi + fw_u_dstate*bwi) * (t<k)
211
+ A += dab_u_state*wa * (t>k) + dy_state*wq * (t>=k)
212
+ dw = tl.sum(A, axis=0,keep_dims=True) + dw0
213
+
214
+ wk = tl.load(w_+IND4(bi,k,hi,i, T,H,C)).to(tl.float32)
215
+ dw *= -wk.exp()
216
+ tl.store(dw_+IND4(bi,k,hi,i, T,H,C), dw.to(tl.bfloat16))
217
+
218
+ dstate = dstate * fw + tl_dot(prec, sdy.trans(), wq) + tl_dot(prec, dab_u.trans(), wa)
219
+ tl.store(ds0_+IND4(bi,hi,i.trans(),i, H,C,C), dstate.to(tl.bfloat16))
220
+
221
+
222
+ class TritonRWKV7(th.autograd.Function):
223
+ @staticmethod
224
+ def forward(ctx, w,q,k,v,z,b,s0, dot_prec):
225
+ K = 16
226
+ B,T,H,C = w.shape
227
+ s0 = th.zeros(B,H,C,C, dtype=w.dtype,device=w.device) if s0 is None else s0
228
+ y = th.empty_like(v)
229
+ sT = th.empty_like(s0)
230
+ s = th.zeros(B,H,T//K,C,C, dtype=th.float32,device=w.device)
231
+ fw_attn_triton[(H,B)](w,q,k,v,z,b, s0,y,s,sT, B,T,H,C,K, dot_prec)
232
+ ctx.dot_prec = dot_prec
233
+ ctx.save_for_backward(w,q,k,v,z,b,s)
234
+ return y, sT
235
+ @staticmethod
236
+ def backward(ctx, dy, dsT):
237
+ K = 16
238
+ w,q,k,v,z,b,s = ctx.saved_tensors
239
+ B,T,H,C = w.shape
240
+ dw,dq,dk,dv,dz,db,ds0 = [th.empty_like(x) for x in [w,q,k,v,z,b,dsT]]
241
+ bw_attn_triton[(H,B)](w,q,k,v,z,b, dy,s,dsT, dw,dq,dk,dv,dz,db,ds0, B,T,H,C,K, ctx.dot_prec)
242
+ return dw,dq,dk,dv,dz,db,ds0,None
243
+
244
+ @triton.jit
245
+ def tl_dot(prec:tl.constexpr, a, b) -> torch.Tensor:
246
+ if prec == 'fp32':
247
+ return tl.dot(a.to(tl.float32),b.trans().to(tl.float32).trans(), allow_tf32=False)
248
+ elif prec == 'tf32':
249
+ return tl.dot(a.to(tl.float32),b.trans().to(tl.float32).trans(), allow_tf32=True)
250
+ elif prec == 'bf16':
251
+ return tl.dot(a.to(tl.bfloat16),b.trans().to(tl.bfloat16).trans(), allow_tf32=True)
252
+ else:
253
+ tl.static_assert(False)
254
+
255
+ def rwkv7_attn_triton(r,w,k,v,a,b, HEAD_SIZE, dot_prec = 'fp32'):
256
+ B,T,HC = w.shape
257
+ C = HEAD_SIZE
258
+ H = HC//C
259
+ r,w,k,v,a,b = [i.view(B,T,H,C) for i in [r,w,k,v,a,b]]
260
+ s0 = th.zeros(B,H,C,C, dtype=th.bfloat16,device=w.device)
261
+ return TritonRWKV7.apply(w,r,k,v,a,b,s0,dot_prec)[0].view(B,T,HC)
262
+
263
+ logger = logging.get_logger(__name__)
264
+
265
+ _CHECKPOINT_FOR_DOC = "RWKV/v7-Goose-1.6B-Pile-HF"
266
+ _CONFIG_FOR_DOC = "Rwkv7Config"
267
+
268
+ class Rwkv7SelfAttention(nn.Module):
269
+ def __init__(self, config, layer_id=0):
270
+ super().__init__()
271
+ self.config = config
272
+ self.layer_id = layer_id
273
+ C = hidden_size = config.hidden_size
274
+ attention_hidden_size = config.attention_hidden_size
275
+ self.attention_hidden_size = attention_hidden_size
276
+ H = self.num_heads = attention_hidden_size // config.head_size
277
+ N = self.head_size = config.head_size
278
+
279
+ calc_lora_rank = lambda exponent, multiplier: max(1, round(hidden_size ** exponent * multiplier / 32)) * 32
280
+ lora_rank_decay = config.lora_rank_decay or calc_lora_rank(0.5, 1.8)
281
+ lora_rank_iclr = config.lora_rank_iclr or calc_lora_rank(0.5, 1.8)
282
+ lora_rank_value_residual_mix = config.lora_rank_value_residual_mix or calc_lora_rank(0.5, 1.3)
283
+ lora_rank_gate = config.lora_rank_gate or calc_lora_rank(0.8, 0.6)
284
+
285
+ self.x_r = nn.Parameter(torch.empty(1,1,C))
286
+ self.x_w = nn.Parameter(torch.empty(1,1,C))
287
+ self.x_k = nn.Parameter(torch.empty(1,1,C))
288
+ self.x_v = nn.Parameter(torch.empty(1,1,C))
289
+ self.x_a = nn.Parameter(torch.empty(1,1,C))
290
+ self.x_g = nn.Parameter(torch.empty(1,1,C))
291
+
292
+ self.w0 = nn.Parameter(torch.empty(1,1,C))
293
+ self.w1 = nn.Parameter(torch.empty(C, lora_rank_decay))
294
+ self.w2 = nn.Parameter(torch.empty(lora_rank_decay, C))
295
+
296
+ self.a0 = nn.Parameter(torch.empty(1,1,C))
297
+ self.a1 = nn.Parameter(torch.empty(C, lora_rank_iclr))
298
+ self.a2 = nn.Parameter(torch.empty(lora_rank_iclr, C))
299
+
300
+ if layer_id > 0:
301
+ self.v0 = nn.Parameter(torch.empty(1,1,C))
302
+ self.v1 = nn.Parameter(torch.empty(C, lora_rank_value_residual_mix))
303
+ self.v2 = nn.Parameter(torch.empty(lora_rank_value_residual_mix, C))
304
+
305
+ self.g1 = nn.Parameter(torch.empty(C, lora_rank_gate))
306
+ self.g2 = nn.Parameter(torch.empty(lora_rank_gate, C))
307
+
308
+ self.k_k = nn.Parameter(torch.empty(1,1,C))
309
+ self.k_a = nn.Parameter(torch.empty(1,1,C))
310
+ self.r_k = nn.Parameter(torch.empty(H,N))
311
+
312
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
313
+ self.receptance = nn.Linear(C, C, bias=False)
314
+ self.key = nn.Linear(C, C, bias=False)
315
+ self.value = nn.Linear(C, C, bias=False)
316
+ self.output = nn.Linear(C, C, bias=False)
317
+ self.ln_x = nn.GroupNorm(H, C, eps=self.head_size * 1e-5)
318
+
319
+
320
+ def forward(self, hidden, state=None, v_first=None, use_cache=False, seq_mode=True):
321
+ # Mix hidden with the previous timestep to produce key, value, receptance
322
+ if hidden.size(1) == 1 and state is not None:
323
+ shifted = state[0][self.layer_id]
324
+ else:
325
+ shifted = self.time_shift(hidden)
326
+ if state is not None:
327
+ shifted[:, 0] = state[0][self.layer_id]
328
+ if len(shifted.size()) == 2:
329
+ shifted = shifted.unsqueeze(1)
330
+
331
+ x = hidden
332
+
333
+ B, T, C = hidden.shape
334
+ H = self.num_heads
335
+ N = self.head_size
336
+
337
+ xx = shifted - x
338
+
339
+ xr = x+xx*self.x_r
340
+ xw = x+xx*self.x_w
341
+ xk = x+xx*self.x_k
342
+ xv = x+xx*self.x_v
343
+ xa = x+xx*self.x_a
344
+ xg = x+xx*self.x_g
345
+
346
+ r = self.receptance(xr)
347
+ w = torch.tanh(xw @ self.w1) @ self.w2
348
+ k = self.key(xk)
349
+ v = self.value(xv)
350
+ a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2)
351
+ g = torch.sigmoid(xg @ self.g1) @ self.g2
352
+
353
+ kk = torch.nn.functional.normalize((k * self.k_k).view(B,T,H,-1), dim=-1, p=2.0).view(B,T,-1)
354
+ k = k * (1 + (a-1) * self.k_a)
355
+ if self.layer_id == 0: v_first = v
356
+ else: v = v + (v_first - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2)
357
+
358
+ if T == 1 or not self.training:
359
+ w = torch.exp(-0.606531 * torch.sigmoid((self.w0 + w).float())) # 0.606531 = exp(-0.5)
360
+ vk_state = state[1][self.layer_id]
361
+ for t in range(T):
362
+ r_, w_, k_, v_, kk_, a_ = r[:,t], w[:,t], k[:,t], v[:,t], kk[:,t], a[:,t]
363
+ vk = v_.view(B,H,N,1) @ k_.view(B,H,1,N)
364
+ ab = (-kk_).view(B,H,N,1) @ (kk_*a_).view(B,H,1,N)
365
+ vk_state = vk_state * w_.view(B,H,1,N) + vk_state @ ab.float() + vk.float()
366
+ xx[:,t] = (vk_state.to(dtype=x.dtype) @ r_.view(B,H,N,1)).view(B,H*N)
367
+ state[1][self.layer_id] = vk_state
368
+ # FIXME - support fast triton kernel for non-training pre-fill with state in and out
369
+ else:
370
+ w = -torch.nn.functional.softplus(-(self.w0 + w)) - 0.5
371
+ rwkv7_attn_triton(r, w, k, v, -kk, kk*a, self.head_size)
372
+
373
+ xx = torch.nn.functional.group_norm(xx.view(B*T,H*N), num_groups=H, weight=self.ln_x.weight, bias=self.ln_x.bias, eps = self.ln_x.eps).view(B,T,H*N)
374
+ #x = x + ((r * k * self.r_k).view(B,T,H,N).sum(dim=-1, keepdim=True) * v.view(B,T,H,N)).view(B,T,H*N)
375
+ xx = xx + ((r.view(B,T,H,-1)*k.view(B,T,H,-1)*self.r_k).sum(dim=-1, keepdim=True) * v.view(B,T,H,-1)).view(B,T,C)
376
+ xx = self.output(xx * g)
377
+
378
+ if state is not None:
379
+ state[0][self.layer_id] = hidden[:, -1]
380
+
381
+ return xx, state, v_first
382
+
383
+
384
+ class Rwkv7FeedForward(nn.Module):
385
+ def __init__(self, config, layer_id=0):
386
+ super().__init__()
387
+ self.config = config
388
+ self.layer_id = layer_id
389
+ hidden_size = config.hidden_size
390
+ intermediate_size = (
391
+ config.intermediate_size
392
+ if config.intermediate_size is not None
393
+ else int(config.hidden_size * 4)
394
+ )
395
+
396
+
397
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
398
+
399
+ self.x_k = nn.Parameter(torch.empty(1, 1, hidden_size))
400
+
401
+ self.key = nn.Linear(hidden_size, intermediate_size, bias=False)
402
+ self.value = nn.Linear(intermediate_size, hidden_size, bias=False)
403
+
404
+ def forward(self, hidden, state=None):
405
+ if hidden.size(1) == 1 and state is not None:
406
+ shifted = state[2][self.layer_id]
407
+ else:
408
+ shifted = self.time_shift(hidden)
409
+ if state is not None:
410
+ shifted[:, 0] = state[2][self.layer_id]
411
+ if len(shifted.size()) == 2:
412
+ shifted = shifted.unsqueeze(1)
413
+
414
+ delta_hidden_to_shifted = shifted - hidden
415
+ key = hidden + delta_hidden_to_shifted * self.x_k
416
+
417
+ key = torch.square(torch.relu(self.key(key)))
418
+ value = self.value(key)
419
+
420
+ if state is not None:
421
+ state[2][self.layer_id] = hidden[:, -1]
422
+
423
+ return value, state
424
+
425
+
426
+ class Rwkv7Block(nn.Module):
427
+ def __init__(self, config, layer_id):
428
+ super().__init__()
429
+ self.config = config
430
+ self.layer_id = layer_id
431
+
432
+ self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
433
+ self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
434
+
435
+ self.attention = Rwkv7SelfAttention(config, layer_id)
436
+ self.feed_forward = Rwkv7FeedForward(config, layer_id)
437
+
438
+ def forward(self, hidden, state=None, v_first=None, use_cache=False, output_attentions=False, seq_mode=True):
439
+ attention, state, v_first = self.attention(self.ln1(hidden), state=state, v_first=v_first, use_cache=use_cache, seq_mode=seq_mode)
440
+ hidden = hidden + attention
441
+
442
+ feed_forward, state = self.feed_forward(self.ln2(hidden), state=state)
443
+ hidden = hidden + feed_forward
444
+
445
+ outputs = (hidden, state, v_first)
446
+ if output_attentions:
447
+ outputs += (attention,)
448
+ else:
449
+ outputs += (None,)
450
+
451
+ return outputs
452
+
453
+
454
+ class Rwkv7PreTrainedModel(PreTrainedModel):
455
+ """
456
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
457
+ models.
458
+ """
459
+
460
+ config_class = Rwkv7Config
461
+ base_model_prefix = "rwkv7"
462
+ _no_split_modules = ["Rwkv7Block"]
463
+ _keep_in_fp32_modules = []
464
+ supports_gradient_checkpointing = True
465
+
466
+ def _init_weights(self, module):
467
+ return
468
+
469
+ """Initialize the weights."""
470
+ if isinstance(module, Rwkv7SelfAttention):
471
+ layer_id = module.layer_id
472
+ num_hidden_layers = module.config.num_hidden_layers
473
+ hidden_size = module.config.hidden_size
474
+ attention_hidden_size = module.attention_hidden_size
475
+ head_size = module.config.head_size
476
+ num_heads = attention_hidden_size // head_size
477
+
478
+ ratio_0_to_1 = layer_id / (num_hidden_layers - 1) # 0 to 1
479
+ ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0
480
+
481
+ time_weight = torch.tensor(
482
+ [i / hidden_size for i in range(hidden_size)],
483
+ dtype=module.x_k.dtype,
484
+ device=module.x_k.device,
485
+ )
486
+ time_weight = time_weight[None, None, :]
487
+
488
+ decay_speed = [
489
+ -7.0 + 5.0 * (n / (attention_hidden_size - 1)) ** (0.85 + 1.0 * ratio_0_to_1 ** 0.5)
490
+ for n in range(attention_hidden_size)
491
+ ]
492
+ decay_speed = torch.tensor(decay_speed, dtype=module.w0.dtype, device=module.w0.device)
493
+
494
+ with torch.no_grad():
495
+ module.x_r.copy_( 1.0 - torch.pow(time_weight, 0.2 * ratio_1_to_almost0) )
496
+ module.x_w.copy_( 1.0 - torch.pow(time_weight, 0.9 * ratio_1_to_almost0) )
497
+ module.x_k.copy_( 1.0 - (torch.pow(time_weight, 0.9 * ratio_1_to_almost0) + 0.4 * ratio_0_to_1) )
498
+ module.x_v.copy_( 1.0 - (torch.pow(time_weight, 0.4 * ratio_1_to_almost0) + 0.6 * ratio_0_to_1) )
499
+ module.x_a.copy_( 1.0 - torch.pow(time_weight, 0.9 * ratio_1_to_almost0) )
500
+ module.x_g.copy_( 1.0 - torch.pow(time_weight, 0.2 * ratio_1_to_almost0) )
501
+
502
+ def ortho_init(x, scale):
503
+ with torch.no_grad():
504
+ shape = x.shape
505
+ if len(shape) == 2:
506
+ gain = math.sqrt(shape[0] / shape[1]) if shape[0] > shape[1] else 1
507
+ nn.init.orthogonal_(x, gain=gain * scale)
508
+ elif len(shape) == 3:
509
+ gain = math.sqrt(shape[1] / shape[2]) if shape[1] > shape[2] else 1
510
+ for i in range(shape[0]):
511
+ nn.init.orthogonal_(x[i], gain=gain * scale)
512
+ else:
513
+ assert False
514
+ return x
515
+
516
+ module.w0.copy_(decay_speed.reshape(1,1,attention_hidden_size) + 0.5) # !!! 0.5 comes from F.softplus !!!
517
+ module.w1.zero_()
518
+ ortho_init(module.w2, 0.1)
519
+
520
+ module.a0.zero_()
521
+ module.a1.zero_()
522
+ ortho_init(module.a2, 0.1)
523
+
524
+ module.v0.copy_(1.0)
525
+ module.v1.zero_()
526
+ ortho_init(module.v2, 0.1)
527
+
528
+ module.g1.zero_()
529
+ ortho_init(module.g2, 0.1)
530
+
531
+ self.k_k.copy_(0.85)
532
+ self.k_a.copy_(1.0)
533
+ self.r_k.zero_()
534
+
535
+ module.receptance.weight.data.uniform_(-0.5/(hidden_size**0.5), 0.5/(attention_hidden_size**0.5))
536
+ module.key.weight.data.uniform_(-0.05/(hidden_size**0.5), 0.05/(attention_hidden_size**0.5))
537
+ module.value.weight.data.uniform_(-0.5/(hidden_size**0.5), 0.5/(attention_hidden_size**0.5))
538
+ module.output.weight.data.zero_()
539
+
540
+ elif isinstance(module, Rwkv7FeedForward):
541
+ layer_id = module.layer_id
542
+ num_hidden_layers = module.config.num_hidden_layers
543
+ hidden_size = module.config.hidden_size
544
+
545
+ ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0
546
+
547
+ time_weight = torch.tensor(
548
+ [i / hidden_size for i in range(hidden_size)],
549
+ dtype=module.x_k.dtype,
550
+ device=module.x_k.device,
551
+ )
552
+ time_weight = time_weight[None, None, :]
553
+
554
+ with torch.no_grad():
555
+ module.x_k.copy_( 1.0 - torch.pow(time_weight, ratio_1_to_almost0**4) )
556
+
557
+ self.key.weight.data.uniform_(-0.5/(hidden_size**0.5), 0.5/(hidden_size**0.5))
558
+ self.value.weight.data.zero_()
559
+
560
+ @dataclass
561
+ class Rwkv7Output(ModelOutput):
562
+ """
563
+ Class for the RWKV model outputs.
564
+ Args:
565
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
566
+ Sequence of hidden-states at the output of the last layer of the model.
567
+ state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
568
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
569
+ avoid providing the old `input_ids`.
570
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
571
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
572
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
573
+ the model at the output of each layer plus the optional initial embedding outputs.
574
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
575
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
576
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
577
+ the self-attention heads.
578
+ """
579
+
580
+ last_hidden_state: torch.FloatTensor = None
581
+ state: Optional[List[torch.FloatTensor]] = None
582
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
583
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
584
+
585
+
586
+ @dataclass
587
+ class Rwkv7CausalLMOutput(ModelOutput):
588
+ """
589
+ Base class for causal language model (or autoregressive) outputs.
590
+ Args:
591
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
592
+ Language modeling loss (for next-token prediction).
593
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
594
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
595
+ state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
596
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
597
+ avoid providing the old `input_ids`.
598
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
599
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
600
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
601
+ the model at the output of each layer plus the optional initial embedding outputs.
602
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
603
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
604
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
605
+ the self-attention heads.
606
+ """
607
+
608
+ loss: Optional[torch.FloatTensor] = None
609
+ logits: torch.FloatTensor = None
610
+ state: Optional[List[torch.FloatTensor]] = None
611
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
612
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
613
+
614
+
615
+ RWKV7_START_DOCSTRING = r"""
616
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
617
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
618
+ etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)
619
+ subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
620
+ general usage and behavior.
621
+ Parameters:
622
+ config ([`Rwkv7Config`]): Model configuration class with all the parameters of the model.
623
+ Initializing with a config file does not load the weights associated with the model, only the
624
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
625
+ """
626
+
627
+ RWKV7_INPUTS_DOCSTRING = r"""
628
+ Args:
629
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
630
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
631
+ `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
632
+ sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their
633
+ past calculated should be passed as `input_ids`. Indices can be obtained using [`AutoTokenizer`]. See
634
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
635
+ IDs?](../glossary#input-ids)
636
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
637
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
638
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
639
+ model's internal embedding lookup matrix.
640
+ state (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*):
641
+ If passed along, the model uses the previous state in all the blocks (which will give the output for the
642
+ `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
643
+ use_cache (`bool`, *optional*):
644
+ If set to `True`, the last state is returned and can be used to quickly generate the next logits.
645
+ output_attentions (`bool`, *optional*):
646
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
647
+ tensors for more detail.
648
+ output_hidden_states (`bool`, *optional*):
649
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
650
+ more detail.
651
+ return_dict (`bool`, *optional*):
652
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
653
+ """
654
+
655
+
656
+ @add_start_docstrings(
657
+ "The bare RWKV7 Model transformer outputting raw hidden-states without any specific head on top.",
658
+ RWKV7_START_DOCSTRING,
659
+ )
660
+ class Rwkv7Model(Rwkv7PreTrainedModel):
661
+ def __init__(self, config):
662
+ super().__init__(config)
663
+
664
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
665
+ self.pre_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
666
+ self.blocks = nn.ModuleList([Rwkv7Block(config, layer_id=idx) for idx in range(config.num_hidden_layers)])
667
+ self.ln_out = nn.LayerNorm(config.hidden_size)
668
+
669
+ self.gradient_checkpointing = False
670
+
671
+ # Initialize weights and apply final processing
672
+ self.post_init()
673
+
674
+ def get_input_embeddings(self):
675
+ return self.embeddings
676
+
677
+ def set_input_embeddings(self, new_embeddings):
678
+ self.embeddings = new_embeddings
679
+
680
+ @add_start_docstrings_to_model_forward(RWKV7_INPUTS_DOCSTRING)
681
+ @add_code_sample_docstrings(
682
+ checkpoint=_CHECKPOINT_FOR_DOC,
683
+ output_type=Rwkv7Output,
684
+ config_class=_CONFIG_FOR_DOC,
685
+ )
686
+ def forward(
687
+ self,
688
+ input_ids: Optional[torch.LongTensor] = None,
689
+ attention_mask: Optional[torch.LongTensor] = None, # noqa
690
+ inputs_embeds: Optional[torch.FloatTensor] = None,
691
+ state: Optional[List[torch.FloatTensor]] = None,
692
+ use_cache: Optional[bool] = None,
693
+ output_attentions: Optional[bool] = None,
694
+ output_hidden_states: Optional[bool] = None,
695
+ return_dict: Optional[bool] = None,
696
+ ) -> Union[Tuple, Rwkv7Output]:
697
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
698
+ output_hidden_states = (
699
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
700
+ )
701
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
702
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
703
+
704
+ if input_ids is not None and inputs_embeds is not None:
705
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
706
+ elif input_ids is None and inputs_embeds is None:
707
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
708
+
709
+ if inputs_embeds is None:
710
+ inputs_embeds = self.embeddings(input_ids)
711
+
712
+ if state is None:
713
+ state = []
714
+ head_size = self.config.head_size
715
+ num_heads = self.config.attention_hidden_size // head_size
716
+ state_attn_x = torch.zeros(
717
+ (self.config.num_hidden_layers, inputs_embeds.size(0), self.config.hidden_size),
718
+ dtype=inputs_embeds.dtype,
719
+ requires_grad=False,
720
+ device=inputs_embeds.device,
721
+ ).contiguous()
722
+ state_attn_vk = torch.zeros(
723
+ (
724
+ self.config.num_hidden_layers,
725
+ inputs_embeds.size(0),
726
+ num_heads,
727
+ head_size,
728
+ head_size,
729
+ ),
730
+ dtype=torch.float32,
731
+ requires_grad=False,
732
+ device=inputs_embeds.device,
733
+ ).contiguous()
734
+ state_ffn_x = torch.zeros(
735
+ (self.config.num_hidden_layers, inputs_embeds.size(0), self.config.hidden_size),
736
+ dtype=inputs_embeds.dtype,
737
+ requires_grad=False,
738
+ device=inputs_embeds.device,
739
+ ).contiguous()
740
+ state.append(state_attn_x)
741
+ state.append(state_attn_vk)
742
+ state.append(state_ffn_x)
743
+
744
+ seq_mode = inputs_embeds.shape[1] > 1
745
+ hidden_states = self.pre_ln(inputs_embeds)
746
+ v_first = None
747
+
748
+ all_self_attentions = () if output_attentions else None
749
+ all_hidden_states = () if output_hidden_states else None
750
+ for idx, block in enumerate(self.blocks):
751
+ hidden_states, state, v_first, attentions = block(
752
+ hidden_states, state=state, v_first=v_first, use_cache=use_cache, output_attentions=output_attentions, seq_mode=seq_mode
753
+ )
754
+
755
+ if output_hidden_states:
756
+ all_hidden_states = all_hidden_states + (hidden_states,)
757
+
758
+ if output_attentions:
759
+ all_self_attentions = all_self_attentions + (attentions,)
760
+
761
+ hidden_states = self.ln_out(hidden_states)
762
+
763
+ if output_hidden_states:
764
+ all_hidden_states = all_hidden_states + (hidden_states,)
765
+
766
+ if not return_dict:
767
+ return (hidden_states, state, all_hidden_states, all_self_attentions)
768
+
769
+ return Rwkv7Output(
770
+ last_hidden_state=hidden_states,
771
+ state=state,
772
+ hidden_states=all_hidden_states, # None
773
+ attentions=all_self_attentions, # None
774
+ )
775
+
776
+ # copied from HuggingFace https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py
777
+ @add_start_docstrings(
778
+ """
779
+ The RWKV7 Model transformer with a language modeling head on top (linear layer with weights tied to the input
780
+ embeddings).
781
+ """,
782
+ RWKV7_START_DOCSTRING,
783
+ )
784
+ class Rwkv7ForCausalLM(Rwkv7PreTrainedModel, GenerationMixin):
785
+ _tied_weights_keys = ["head.weight"]
786
+
787
+ def __init__(self, config):
788
+ super().__init__(config)
789
+ self.model = Rwkv7Model(config)
790
+ self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
791
+
792
+ # Initialize weights and apply final processing
793
+ self.post_init()
794
+
795
+ def get_output_embeddings(self):
796
+ return self.head
797
+
798
+ def set_output_embeddings(self, new_embeddings):
799
+ self.head = new_embeddings
800
+
801
+ def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, **kwargs):
802
+ # only last token for inputs_ids if the state is passed along.
803
+ if state is not None:
804
+ input_ids = input_ids[:, -1].unsqueeze(-1)
805
+
806
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
807
+ if inputs_embeds is not None and state is None:
808
+ model_inputs = {"inputs_embeds": inputs_embeds}
809
+ else:
810
+ model_inputs = {"input_ids": input_ids}
811
+
812
+ model_inputs["state"] = state
813
+ return model_inputs
814
+
815
+ @add_start_docstrings_to_model_forward(RWKV7_INPUTS_DOCSTRING)
816
+ @add_code_sample_docstrings(
817
+ checkpoint=_CHECKPOINT_FOR_DOC,
818
+ output_type=Rwkv7CausalLMOutput,
819
+ config_class=_CONFIG_FOR_DOC,
820
+ )
821
+ def forward(
822
+ self,
823
+ input_ids: Optional[torch.LongTensor] = None,
824
+ attention_mask: Optional[torch.LongTensor] = None,
825
+ inputs_embeds: Optional[torch.FloatTensor] = None,
826
+ state: Optional[List[torch.FloatTensor]] = None,
827
+ labels: Optional[torch.LongTensor] = None,
828
+ use_cache: Optional[bool] = None,
829
+ output_attentions: Optional[bool] = None,
830
+ output_hidden_states: Optional[bool] = None,
831
+ return_dict: Optional[bool] = None,
832
+ ) -> Union[Tuple, Rwkv7CausalLMOutput]:
833
+ r"""
834
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
835
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
836
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
837
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
838
+ """
839
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
840
+
841
+ outputs = self.model(
842
+ input_ids,
843
+ inputs_embeds=inputs_embeds,
844
+ state=state,
845
+ use_cache=use_cache,
846
+ output_attentions=output_attentions,
847
+ output_hidden_states=output_hidden_states,
848
+ return_dict=return_dict,
849
+ )
850
+ hidden_states = outputs[0]
851
+
852
+ logits = self.head(hidden_states)
853
+
854
+ loss = None
855
+ if labels is not None:
856
+ # move labels to correct device to enable model parallelism
857
+ labels = labels.to(logits.device)
858
+ # Shift so that tokens < n predict n
859
+ shift_logits = logits[..., :-1, :].contiguous()
860
+ shift_labels = labels[..., 1:].contiguous()
861
+ # Flatten the tokens
862
+ loss_fct = CrossEntropyLoss()
863
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
864
+
865
+ if not return_dict:
866
+ output = (logits,) + outputs[1:]
867
+ return ((loss,) + output) if loss is not None else output
868
+
869
+ return Rwkv7CausalLMOutput(
870
+ loss=loss,
871
+ logits=logits,
872
+ state=outputs.state,
873
+ hidden_states=outputs.hidden_states,
874
+ attentions=outputs.attentions,
875
+ )
rwkv_vocab_v20230424.txt ADDED
The diff for this file is too large to render. See raw diff
 
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<s>",
3
+ "eos_token": "<s>",
4
+ "unk_token": "<s>"
5
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "<s>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ }
12
+ },
13
+ "auto_map": {
14
+ "AutoTokenizer": [
15
+ "hf_rwkv_tokenizer.Rwkv6Tokenizer",
16
+ null
17
+ ]
18
+ },
19
+ "bos_token": "<s>",
20
+ "clean_up_tokenization_spaces": false,
21
+ "eos_token": "<s>",
22
+ "model_max_length": 1000000000000000019884624838656,
23
+ "tokenizer_class": "Rwkv6Tokenizer",
24
+ "unk_token": "<s>",
25
+ "use_fast": false
26
+ }