yhirokawa committed
Commit 65176a6 · 0 Parent(s)

initial commit
.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,202 @@
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
178
+
179
+ APPENDIX: How to apply the Apache License to your work.
180
+
181
+ To apply the Apache License to your work, attach the following
182
+ boilerplate notice, with the fields enclosed by brackets "[]"
183
+ replaced with your own identifying information. (Don't include
184
+ the brackets!) The text should be enclosed in the appropriate
185
+ comment syntax for the file format. We also recommend that a
186
+ file or class name and description of purpose be included on the
187
+ same "printed page" as the copyright notice for easier
188
+ identification within third-party archives.
189
+
190
+ Copyright 2025 Preferred Elements, Inc.
191
+
192
+ Licensed under the Apache License, Version 2.0 (the "License");
193
+ you may not use this file except in compliance with the License.
194
+ You may obtain a copy of the License at
195
+
196
+ http://www.apache.org/licenses/LICENSE-2.0
197
+
198
+ Unless required by applicable law or agreed to in writing, software
199
+ distributed under the License is distributed on an "AS IS" BASIS,
200
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
+ See the License for the specific language governing permissions and
202
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,111 @@
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ - ja
6
+ pipeline_tag: text-generation
7
+ library_name: transformers
8
+ ---
9
+
10
+ # PLaMo 2 1B
11
+
12
+
13
+ ## Model Description
14
+ PLaMo 2 1B is a 1-billion-parameter model pre-trained on English and Japanese datasets, developed by Preferred Elements, Inc.
15
+
16
+ PLaMo 2 1B is released under Apache License version 2.0.
17
+
18
+ **NOTE**: This model has **NOT** been instruction-tuned for chat dialog or other downstream tasks.
19
+
20
+
21
+ ## Usage
22
+
23
+ ### Requirements
24
+
25
+ ```
26
+ numpy>=1.26.4
27
+ numba>=0.60.0
28
+ torch>=2.4.1
29
+ transformers>=4.44.2
30
+ mamba_ssm>=2.2.2
31
+ causal_conv1d>=1.4.0
32
+ ```
33
+
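+ The snippet below is an optional sanity check (not part of the official instructions) that the pinned packages above are importable; the distribution names are assumptions and may differ slightly from the import names (e.g. `mamba-ssm`), so adjust as needed.
+
+ ```python
+ # Hypothetical helper: report installed versions of the requirements listed above.
+ from importlib.metadata import version, PackageNotFoundError
+
+ requirements = {
+     "numpy": "1.26.4",
+     "numba": "0.60.0",
+     "torch": "2.4.1",
+     "transformers": "4.44.2",
+     "mamba_ssm": "2.2.2",
+     "causal_conv1d": "1.4.0",
+ }
+ for name, minimum in requirements.items():
+     try:
+         print(f"{name}: found {version(name)} (needs >= {minimum})")
+     except PackageNotFoundError:
+         print(f"{name}: not installed (needs >= {minimum})")
+ ```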
34
+ ### Use a pipeline as a high-level helper
35
+
36
+ ```python
37
+ import transformers
38
+ pipeline = transformers.pipeline("text-generation", model="pfnet/plamo-2-1b", trust_remote_code=True)
39
+ print(pipeline("The future of artificial intelligence technology is ", max_new_tokens=32))
40
+ ```
41
+
42
+ ### Load model directly
43
+
44
+ ```python
45
+ from transformers import AutoTokenizer, AutoModelForCausalLM
46
+ tokenizer = AutoTokenizer.from_pretrained("pfnet/plamo-2-1b", trust_remote_code=True)
47
+ model = AutoModelForCausalLM.from_pretrained("pfnet/plamo-2-1b", trust_remote_code=True)
48
+ text = "これからの人工知能技術は"
49
+ input_ids = tokenizer(text, return_tensors="pt").input_ids
50
+ generated_tokens = model.generate(
51
+ inputs=input_ids,
52
+ max_new_tokens=32,
53
+ do_sample=True,
54
+ top_k=50,
55
+ top_p=0.95,
56
+ temperature=1.0,
57
+ )[0]
58
+ generated_text = tokenizer.decode(generated_tokens)
59
+ print(generated_text)
60
+ ```
61
+
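+ The snippets above run on CPU by default. A hedged variant for GPU execution (assumes a CUDA device and that the optional mamba_ssm / causal_conv1d kernels are installed):
+
+ ```python
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ tokenizer = AutoTokenizer.from_pretrained("pfnet/plamo-2-1b", trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(
+     "pfnet/plamo-2-1b", trust_remote_code=True, torch_dtype=torch.bfloat16
+ ).to("cuda")
+ input_ids = tokenizer("これからの人工知能技術は", return_tensors="pt").input_ids.to("cuda")
+ print(tokenizer.decode(model.generate(inputs=input_ids, max_new_tokens=32)[0]))
+ ```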
62
+
63
+ ## Model Details
64
+
65
+ - Model size: 1B
66
+ - Trained tokens: 4T tokens
67
+ - Developed by: Preferred Elements, Inc.
68
+ - Model type: Causal decoder-only
69
+ - Language(s): English, Japanese
70
+ - License: Apache License version 2.0
71
+
72
+
73
+ ## Training Dataset
74
+
75
+ We trained PLaMo 2 1B in two phases: phase 1 on 3.5T tokens and phase 2 on 0.5T tokens.
76
+ The dataset mix in each phase is shown in the following table.
77
+
78
+ |Dataset|3.5T (phase 1)|0.5T (phase 2)|Tokens|
79
+ |---|:---:|:---:|:---:|
80
+ |English|45 %|35 %|1.75 T|
81
+ |Japanese|30 %|40 %|1.25 T|
82
+ |Coding|15 %|15 %|0.6 T|
83
+ |Other|10 %|10 %|0.4 T|
84
+
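+ As a quick cross-check, the "Tokens" column follows from the phase sizes and the percentages above; a small sketch of the arithmetic:
+
+ ```python
+ # Illustrative arithmetic check of the token counts in the table (table values are rounded).
+ phase_tokens = {"phase1": 3.5e12, "phase2": 0.5e12}
+ mix = {
+     "English": (0.45, 0.35),
+     "Japanese": (0.30, 0.40),
+     "Coding": (0.15, 0.15),
+     "Other": (0.10, 0.10),
+ }
+ for name, (p1, p2) in mix.items():
+     total = p1 * phase_tokens["phase1"] + p2 * phase_tokens["phase2"]
+     print(f"{name}: {total / 1e12:.2f} T tokens")
+ ```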
85
+
86
+ ## Tokenizer
87
+
88
+ The PLaMo 2 1B tokenizer is optimized with numba, a JIT compiler for numerical functions.
89
+ The tokenizer is trained on a subset of the datasets for model pre-training.
90
+
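+ A minimal sketch of inspecting the tokenizer on its own (same `AutoTokenizer` API as in the usage section above):
+
+ ```python
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("pfnet/plamo-2-1b", trust_remote_code=True)
+ ids = tokenizer("これからの人工知能技術は", return_tensors="pt").input_ids
+ print(ids.shape)                 # number of tokens produced for the prompt
+ print(tokenizer.decode(ids[0]))  # round-trip back to text
+ ```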
91
+
92
+ ## Tech Blog
93
+
94
+ - (JA) https://tech.preferred.jp/ja/blog/plamo-2/
95
+ - (JA) https://tech.preferred.jp/ja/blog/plamo-2-tokenizer/
96
+
97
+
98
+ ## Bias, Risks, and Limitations
99
+
100
+ PLaMo 2 1B is a new technology that carries risks with use. Testing conducted to date has been in English and Japanese, and has not covered, nor could it cover, all scenarios. For these reasons, as with all LLMs, PLaMo 2 1B’s potential outputs cannot be predicted in advance, and the model may in some instances produce inaccurate, biased, or otherwise objectionable responses to user prompts. Therefore, before deploying any application of PLaMo 2 1B, developers should perform safety testing and tuning tailored to their specific use of the model.
101
+
102
+
103
+ ## Acknowledgement
104
+
105
+ This model was trained under the “Research and Development Project of the Enhanced Infrastructures for Post 5G Information and Communication System” (JPNP 20017), subsidized by the New Energy and Industrial Technology Development Organization (NEDO).
106
+
107
+
108
+ ## AI policies for Preferred Networks, Inc. group
109
+
110
+ - (EN) https://www.preferred.jp/en/company/aipolicy/
111
+ - (JA) https://www.preferred.jp/ja/company/aipolicy/
config.json ADDED
@@ -0,0 +1,49 @@
1
+ {
2
+ "architectures": [
3
+ "PlamoForCausalLM"
4
+ ],
5
+ "attention_window_size": 2048,
6
+ "auto_map": {
7
+ "AutoConfig": "modeling_plamo.PlamoConfig",
8
+ "AutoModelForCausalLM": "modeling_plamo.PlamoForCausalLM"
9
+ },
10
+ "bos_token_id": 1,
11
+ "capacity_factor": 1.0,
12
+ "eos_token_id": 2,
13
+ "eval_attention_n_bit": null,
14
+ "eval_mlp_n_bit": null,
15
+ "expert_dropout": 0.0,
16
+ "fp8_accum_dtype": "bfloat16",
17
+ "group_size": 1024,
18
+ "hidden_size": 2048,
19
+ "hidden_size_per_head": 128,
20
+ "image_feature_size": null,
21
+ "image_proj_type": "linear",
22
+ "image_token_id": null,
23
+ "intermediate_size": 8192,
24
+ "k_expert": null,
25
+ "linear_type": "fp8",
26
+ "mamba_chunk_size": 256,
27
+ "mamba_d_conv": 4,
28
+ "mamba_d_state": 64,
29
+ "mamba_enabled": true,
30
+ "mamba_num_heads": 32,
31
+ "mamba_step": 2,
32
+ "max_position_embeddings": 10485760,
33
+ "model_type": "plamo",
34
+ "n_expert": null,
35
+ "num_attention_heads": 16,
36
+ "num_hidden_layers": 16,
37
+ "num_key_value_heads": 1,
38
+ "rms_norm_eps": 1e-06,
39
+ "shared_intermediate_size": null,
40
+ "sliding_window": 2048,
41
+ "sparse_intermediate_size": null,
42
+ "sparse_step": null,
43
+ "tokenizer_class": "PlamoTokenizer",
44
+ "torch_dtype": "float32",
45
+ "transformers_version": "4.44.2",
46
+ "use_cache": true,
47
+ "use_predefined_initial_state": false,
48
+ "vocab_size": 100000
49
+ }
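The configuration above is consumed through the `auto_map` entries with `trust_remote_code`; a minimal sketch (assuming network access to `pfnet/plamo-2-1b`) of loading and inspecting it:

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("pfnet/plamo-2-1b", trust_remote_code=True)
print(config.model_type)                         # "plamo"
print(config.num_hidden_layers)                  # 16
print(config.mamba_enabled, config.mamba_step)   # Mamba mixers interleaved with attention every mamba_step layers
print(config.attention_window_size)              # 2048 (sliding-window attention)
```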
generation_config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "transformers_version": "4.44.2"
6
+ }
model-00001-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:306bfd09ff22b5c044abfbc310b0a916a85a450ff52beaa90e245bb67b4cb1e9
3
+ size 4060618320
model-00002-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:759d54a62b8f8124031bf4048b2b7e9b9fe3439e75ce772e48562369768836a8
3
+ size 285214592
model-00003-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7172210250c2b4030b1bc9c75a68224844e770838159e939bb6c68593d2e235
3
+ size 819200136
model-00004-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:017d164149a8e31acdde2f8c6dd45568d7a8900951f83cb5fbcebac64c662695
3
+ size 621448
model-00005-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6bfe7b7d45c40c75d7ff7b362fd88bf34e23aedf6a637df4eafc9f7d6231f8ca
3
+ size 131976
model-00006-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3dadc185925d0be7f495c51924546b0f07f93550cf4222da96dedd335beb3b8f
3
+ size 5296
model-00007-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60d95b10b6e140a9626a7058d5038528f2ff80148dc4569b881db56052046509
3
+ size 40
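The `*.safetensors` entries above are Git LFS pointer files: `oid` is the SHA-256 of the actual weight shard and `size` is its byte count. A minimal sketch (local path is hypothetical) for verifying a downloaded shard against its pointer:

```python
import hashlib

def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    # Stream the file in chunks so large shards do not need to fit in memory.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# Expected for the first shard: 306bfd09ff22b5c044abfbc310b0a916a85a450ff52beaa90e245bb67b4cb1e9
print(sha256_of("model-00001-of-00007.safetensors"))
```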
model.safetensors.index.json ADDED
@@ -0,0 +1,223 @@
1
+ {
2
+ "weight_map": {
3
+ "model.layers.layers.0.mixer.in_proj.weight": "model-00001-of-00007.safetensors",
4
+ "model.layers.layers.0.mixer.conv1d.weight": "model-00001-of-00007.safetensors",
5
+ "model.layers.layers.0.mixer.bcdt_proj.weight": "model-00001-of-00007.safetensors",
6
+ "model.layers.layers.0.mixer.out_proj.weight": "model-00001-of-00007.safetensors",
7
+ "model.layers.layers.2.mixer.in_proj.weight": "model-00001-of-00007.safetensors",
8
+ "model.layers.layers.2.mixer.conv1d.weight": "model-00001-of-00007.safetensors",
9
+ "model.layers.layers.2.mixer.bcdt_proj.weight": "model-00001-of-00007.safetensors",
10
+ "model.layers.layers.2.mixer.out_proj.weight": "model-00001-of-00007.safetensors",
11
+ "model.layers.layers.4.mixer.in_proj.weight": "model-00001-of-00007.safetensors",
12
+ "model.layers.layers.4.mixer.conv1d.weight": "model-00001-of-00007.safetensors",
13
+ "model.layers.layers.4.mixer.bcdt_proj.weight": "model-00001-of-00007.safetensors",
14
+ "model.layers.layers.4.mixer.out_proj.weight": "model-00001-of-00007.safetensors",
15
+ "model.layers.layers.6.mixer.in_proj.weight": "model-00001-of-00007.safetensors",
16
+ "model.layers.layers.6.mixer.conv1d.weight": "model-00001-of-00007.safetensors",
17
+ "model.layers.layers.6.mixer.bcdt_proj.weight": "model-00001-of-00007.safetensors",
18
+ "model.layers.layers.6.mixer.out_proj.weight": "model-00001-of-00007.safetensors",
19
+ "model.layers.layers.8.mixer.in_proj.weight": "model-00001-of-00007.safetensors",
20
+ "model.layers.layers.8.mixer.conv1d.weight": "model-00001-of-00007.safetensors",
21
+ "model.layers.layers.8.mixer.bcdt_proj.weight": "model-00001-of-00007.safetensors",
22
+ "model.layers.layers.8.mixer.out_proj.weight": "model-00001-of-00007.safetensors",
23
+ "model.layers.layers.10.mixer.in_proj.weight": "model-00001-of-00007.safetensors",
24
+ "model.layers.layers.10.mixer.conv1d.weight": "model-00001-of-00007.safetensors",
25
+ "model.layers.layers.10.mixer.bcdt_proj.weight": "model-00001-of-00007.safetensors",
26
+ "model.layers.layers.10.mixer.out_proj.weight": "model-00001-of-00007.safetensors",
27
+ "model.layers.layers.12.mixer.in_proj.weight": "model-00001-of-00007.safetensors",
28
+ "model.layers.layers.12.mixer.conv1d.weight": "model-00001-of-00007.safetensors",
29
+ "model.layers.layers.12.mixer.bcdt_proj.weight": "model-00001-of-00007.safetensors",
30
+ "model.layers.layers.12.mixer.out_proj.weight": "model-00001-of-00007.safetensors",
31
+ "model.layers.layers.14.mixer.in_proj.weight": "model-00001-of-00007.safetensors",
32
+ "model.layers.layers.14.mixer.conv1d.weight": "model-00001-of-00007.safetensors",
33
+ "model.layers.layers.14.mixer.bcdt_proj.weight": "model-00001-of-00007.safetensors",
34
+ "model.layers.layers.14.mixer.out_proj.weight": "model-00001-of-00007.safetensors",
35
+ "model.layers.layers.1.mixer.qkv_proj.weight": "model-00002-of-00007.safetensors",
36
+ "model.layers.layers.1.mixer.o_proj.weight": "model-00002-of-00007.safetensors",
37
+ "model.layers.layers.3.mixer.qkv_proj.weight": "model-00002-of-00007.safetensors",
38
+ "model.layers.layers.3.mixer.o_proj.weight": "model-00002-of-00007.safetensors",
39
+ "model.layers.layers.5.mixer.qkv_proj.weight": "model-00002-of-00007.safetensors",
40
+ "model.layers.layers.5.mixer.o_proj.weight": "model-00002-of-00007.safetensors",
41
+ "model.layers.layers.7.mixer.qkv_proj.weight": "model-00002-of-00007.safetensors",
42
+ "model.layers.layers.7.mixer.o_proj.weight": "model-00002-of-00007.safetensors",
43
+ "model.layers.layers.9.mixer.qkv_proj.weight": "model-00002-of-00007.safetensors",
44
+ "model.layers.layers.9.mixer.o_proj.weight": "model-00002-of-00007.safetensors",
45
+ "model.layers.layers.11.mixer.qkv_proj.weight": "model-00002-of-00007.safetensors",
46
+ "model.layers.layers.11.mixer.o_proj.weight": "model-00002-of-00007.safetensors",
47
+ "model.layers.layers.13.mixer.qkv_proj.weight": "model-00002-of-00007.safetensors",
48
+ "model.layers.layers.13.mixer.o_proj.weight": "model-00002-of-00007.safetensors",
49
+ "model.layers.layers.15.mixer.qkv_proj.weight": "model-00002-of-00007.safetensors",
50
+ "model.layers.layers.15.mixer.o_proj.weight": "model-00002-of-00007.safetensors",
51
+ "model.embed_tokens.weight": "model-00003-of-00007.safetensors",
52
+ "model.layers.layers.0.mixer.dt_norm_weight": "model-00004-of-00007.safetensors",
53
+ "model.layers.layers.0.mixer.B_norm_weight": "model-00004-of-00007.safetensors",
54
+ "model.layers.layers.0.mixer.C_norm_weight": "model-00004-of-00007.safetensors",
55
+ "model.layers.layers.0.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
56
+ "model.layers.layers.0.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
57
+ "model.layers.layers.0.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
58
+ "model.layers.layers.0.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
59
+ "model.layers.layers.1.mixer.q_weight": "model-00004-of-00007.safetensors",
60
+ "model.layers.layers.1.mixer.k_weight": "model-00004-of-00007.safetensors",
61
+ "model.layers.layers.1.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
62
+ "model.layers.layers.1.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
63
+ "model.layers.layers.1.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
64
+ "model.layers.layers.1.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
65
+ "model.layers.layers.2.mixer.dt_norm_weight": "model-00004-of-00007.safetensors",
66
+ "model.layers.layers.2.mixer.B_norm_weight": "model-00004-of-00007.safetensors",
67
+ "model.layers.layers.2.mixer.C_norm_weight": "model-00004-of-00007.safetensors",
68
+ "model.layers.layers.2.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
69
+ "model.layers.layers.2.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
70
+ "model.layers.layers.2.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
71
+ "model.layers.layers.2.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
72
+ "model.layers.layers.3.mixer.q_weight": "model-00004-of-00007.safetensors",
73
+ "model.layers.layers.3.mixer.k_weight": "model-00004-of-00007.safetensors",
74
+ "model.layers.layers.3.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
75
+ "model.layers.layers.3.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
76
+ "model.layers.layers.3.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
77
+ "model.layers.layers.3.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
78
+ "model.layers.layers.4.mixer.dt_norm_weight": "model-00004-of-00007.safetensors",
79
+ "model.layers.layers.4.mixer.B_norm_weight": "model-00004-of-00007.safetensors",
80
+ "model.layers.layers.4.mixer.C_norm_weight": "model-00004-of-00007.safetensors",
81
+ "model.layers.layers.4.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
82
+ "model.layers.layers.4.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
83
+ "model.layers.layers.4.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
84
+ "model.layers.layers.4.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
85
+ "model.layers.layers.5.mixer.q_weight": "model-00004-of-00007.safetensors",
86
+ "model.layers.layers.5.mixer.k_weight": "model-00004-of-00007.safetensors",
87
+ "model.layers.layers.5.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
88
+ "model.layers.layers.5.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
89
+ "model.layers.layers.5.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
90
+ "model.layers.layers.5.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
91
+ "model.layers.layers.6.mixer.dt_norm_weight": "model-00004-of-00007.safetensors",
92
+ "model.layers.layers.6.mixer.B_norm_weight": "model-00004-of-00007.safetensors",
93
+ "model.layers.layers.6.mixer.C_norm_weight": "model-00004-of-00007.safetensors",
94
+ "model.layers.layers.6.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
95
+ "model.layers.layers.6.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
96
+ "model.layers.layers.6.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
97
+ "model.layers.layers.6.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
98
+ "model.layers.layers.7.mixer.q_weight": "model-00004-of-00007.safetensors",
99
+ "model.layers.layers.7.mixer.k_weight": "model-00004-of-00007.safetensors",
100
+ "model.layers.layers.7.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
101
+ "model.layers.layers.7.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
102
+ "model.layers.layers.7.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
103
+ "model.layers.layers.7.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
104
+ "model.layers.layers.8.mixer.dt_norm_weight": "model-00004-of-00007.safetensors",
105
+ "model.layers.layers.8.mixer.B_norm_weight": "model-00004-of-00007.safetensors",
106
+ "model.layers.layers.8.mixer.C_norm_weight": "model-00004-of-00007.safetensors",
107
+ "model.layers.layers.8.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
108
+ "model.layers.layers.8.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
109
+ "model.layers.layers.8.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
110
+ "model.layers.layers.8.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
111
+ "model.layers.layers.9.mixer.q_weight": "model-00004-of-00007.safetensors",
112
+ "model.layers.layers.9.mixer.k_weight": "model-00004-of-00007.safetensors",
113
+ "model.layers.layers.9.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
114
+ "model.layers.layers.9.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
115
+ "model.layers.layers.9.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
116
+ "model.layers.layers.9.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
117
+ "model.layers.layers.10.mixer.dt_norm_weight": "model-00004-of-00007.safetensors",
118
+ "model.layers.layers.10.mixer.B_norm_weight": "model-00004-of-00007.safetensors",
119
+ "model.layers.layers.10.mixer.C_norm_weight": "model-00004-of-00007.safetensors",
120
+ "model.layers.layers.10.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
121
+ "model.layers.layers.10.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
122
+ "model.layers.layers.10.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
123
+ "model.layers.layers.10.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
124
+ "model.layers.layers.11.mixer.q_weight": "model-00004-of-00007.safetensors",
125
+ "model.layers.layers.11.mixer.k_weight": "model-00004-of-00007.safetensors",
126
+ "model.layers.layers.11.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
127
+ "model.layers.layers.11.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
128
+ "model.layers.layers.11.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
129
+ "model.layers.layers.11.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
130
+ "model.layers.layers.12.mixer.dt_norm_weight": "model-00004-of-00007.safetensors",
131
+ "model.layers.layers.12.mixer.B_norm_weight": "model-00004-of-00007.safetensors",
132
+ "model.layers.layers.12.mixer.C_norm_weight": "model-00004-of-00007.safetensors",
133
+ "model.layers.layers.12.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
134
+ "model.layers.layers.12.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
135
+ "model.layers.layers.12.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
136
+ "model.layers.layers.12.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
137
+ "model.layers.layers.13.mixer.q_weight": "model-00004-of-00007.safetensors",
138
+ "model.layers.layers.13.mixer.k_weight": "model-00004-of-00007.safetensors",
139
+ "model.layers.layers.13.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
140
+ "model.layers.layers.13.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
141
+ "model.layers.layers.13.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
142
+ "model.layers.layers.13.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
143
+ "model.layers.layers.14.mixer.dt_norm_weight": "model-00004-of-00007.safetensors",
144
+ "model.layers.layers.14.mixer.B_norm_weight": "model-00004-of-00007.safetensors",
145
+ "model.layers.layers.14.mixer.C_norm_weight": "model-00004-of-00007.safetensors",
146
+ "model.layers.layers.14.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
147
+ "model.layers.layers.14.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
148
+ "model.layers.layers.14.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
149
+ "model.layers.layers.14.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
150
+ "model.layers.layers.15.mixer.q_weight": "model-00004-of-00007.safetensors",
151
+ "model.layers.layers.15.mixer.k_weight": "model-00004-of-00007.safetensors",
152
+ "model.layers.layers.15.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
153
+ "model.layers.layers.15.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
154
+ "model.layers.layers.15.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
155
+ "model.layers.layers.15.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
156
+ "model.norm.weight": "model-00004-of-00007.safetensors",
157
+ "model.layers.layers.0.mixer.dt_proj.weight": "model-00005-of-00007.safetensors",
158
+ "model.layers.layers.2.mixer.dt_proj.weight": "model-00005-of-00007.safetensors",
159
+ "model.layers.layers.4.mixer.dt_proj.weight": "model-00005-of-00007.safetensors",
160
+ "model.layers.layers.6.mixer.dt_proj.weight": "model-00005-of-00007.safetensors",
161
+ "model.layers.layers.8.mixer.dt_proj.weight": "model-00005-of-00007.safetensors",
162
+ "model.layers.layers.10.mixer.dt_proj.weight": "model-00005-of-00007.safetensors",
163
+ "model.layers.layers.12.mixer.dt_proj.weight": "model-00005-of-00007.safetensors",
164
+ "model.layers.layers.14.mixer.dt_proj.weight": "model-00005-of-00007.safetensors",
165
+ "model.layers.layers.0.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
166
+ "model.layers.layers.0.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
167
+ "model.layers.layers.1.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
168
+ "model.layers.layers.1.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
169
+ "model.layers.layers.2.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
170
+ "model.layers.layers.2.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
171
+ "model.layers.layers.3.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
172
+ "model.layers.layers.3.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
173
+ "model.layers.layers.4.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
174
+ "model.layers.layers.4.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
175
+ "model.layers.layers.5.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
176
+ "model.layers.layers.5.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
177
+ "model.layers.layers.6.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
178
+ "model.layers.layers.6.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
179
+ "model.layers.layers.7.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
180
+ "model.layers.layers.7.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
181
+ "model.layers.layers.8.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
182
+ "model.layers.layers.8.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
183
+ "model.layers.layers.9.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
184
+ "model.layers.layers.9.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
185
+ "model.layers.layers.10.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
186
+ "model.layers.layers.10.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
187
+ "model.layers.layers.11.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
188
+ "model.layers.layers.11.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
189
+ "model.layers.layers.12.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
190
+ "model.layers.layers.12.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
191
+ "model.layers.layers.13.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
192
+ "model.layers.layers.13.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
193
+ "model.layers.layers.14.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
194
+ "model.layers.layers.14.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
195
+ "model.layers.layers.15.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
196
+ "model.layers.layers.15.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
197
+ "model.layers.layers.0.mixer.dt_bias": "model-00006-of-00007.safetensors",
198
+ "model.layers.layers.0.mixer.A_log": "model-00006-of-00007.safetensors",
199
+ "model.layers.layers.0.mixer.D": "model-00006-of-00007.safetensors",
200
+ "model.layers.layers.2.mixer.dt_bias": "model-00006-of-00007.safetensors",
201
+ "model.layers.layers.2.mixer.A_log": "model-00006-of-00007.safetensors",
202
+ "model.layers.layers.2.mixer.D": "model-00006-of-00007.safetensors",
203
+ "model.layers.layers.4.mixer.dt_bias": "model-00006-of-00007.safetensors",
204
+ "model.layers.layers.4.mixer.A_log": "model-00006-of-00007.safetensors",
205
+ "model.layers.layers.4.mixer.D": "model-00006-of-00007.safetensors",
206
+ "model.layers.layers.6.mixer.dt_bias": "model-00006-of-00007.safetensors",
207
+ "model.layers.layers.6.mixer.A_log": "model-00006-of-00007.safetensors",
208
+ "model.layers.layers.6.mixer.D": "model-00006-of-00007.safetensors",
209
+ "model.layers.layers.8.mixer.dt_bias": "model-00006-of-00007.safetensors",
210
+ "model.layers.layers.8.mixer.A_log": "model-00006-of-00007.safetensors",
211
+ "model.layers.layers.8.mixer.D": "model-00006-of-00007.safetensors",
212
+ "model.layers.layers.10.mixer.dt_bias": "model-00006-of-00007.safetensors",
213
+ "model.layers.layers.10.mixer.A_log": "model-00006-of-00007.safetensors",
214
+ "model.layers.layers.10.mixer.D": "model-00006-of-00007.safetensors",
215
+ "model.layers.layers.12.mixer.dt_bias": "model-00006-of-00007.safetensors",
216
+ "model.layers.layers.12.mixer.A_log": "model-00006-of-00007.safetensors",
217
+ "model.layers.layers.12.mixer.D": "model-00006-of-00007.safetensors",
218
+ "model.layers.layers.14.mixer.dt_bias": "model-00006-of-00007.safetensors",
219
+ "model.layers.layers.14.mixer.A_log": "model-00006-of-00007.safetensors",
220
+ "model.layers.layers.14.mixer.D": "model-00006-of-00007.safetensors"
221
+ },
222
+ "metadata": {}
223
+ }
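The `weight_map` above tells the loader which shard contains each parameter; `transformers` reads it automatically, but it can also be inspected directly. A minimal sketch (assuming the index file has been downloaded locally):

```python
import json

with open("model.safetensors.index.json") as f:
    index = json.load(f)

weight_map = index["weight_map"]
print(weight_map["model.embed_tokens.weight"])                    # model-00003-of-00007.safetensors
print(weight_map["model.layers.layers.0.mlp.down_proj.weight"])   # model-00001-of-00007.safetensors
```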
modeling_plamo.py ADDED
@@ -0,0 +1,1616 @@
1
+ import enum
2
+ import math
3
+ import warnings
4
+ from typing import Any, Dict, List, Literal, NamedTuple, Optional, Tuple, Union
5
+
6
+ try:
7
+ # It is difficult to install mamba_ssm on a login node because
8
+ # it requires a GPU for installation
9
+ import mamba_ssm
10
+ except ModuleNotFoundError:
11
+ warnings.warn("mamba_ssm could not be imported", stacklevel=2)
12
+ try:
13
+ # It is difficult to install causal_conv1d on a login node because
14
+ # it requires a GPU for installation
15
+ import causal_conv1d.causal_conv1d_interface as causal_conv1d
16
+ except ModuleNotFoundError:
17
+ warnings.warn("causal_conv1d could not be imported", stacklevel=2)
18
+ import torch
19
+ from torch import nn
20
+ from torch.nn import functional as F
21
+ from transformers import PretrainedConfig, PreTrainedModel
22
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
23
+
24
+
25
+ def _is_first_token(mask: torch.Tensor) -> torch.Tensor:
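+ # True for query positions that cannot attend to the immediately preceding key, i.e. the first token of an attention segment.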
26
+ assert mask.dtype == torch.bool
27
+ B, Nh, q_len, kv_len = mask.shape
28
+ mask = mask[:, :, :, -q_len:]
29
+ cont = q_len != kv_len
30
+ v = False if cont else True
31
+ out = torch.logical_not(torch.diagonal(mask, offset=-1, dim1=-2, dim2=-1).bool())
32
+ out = torch.cat(
33
+ [
34
+ torch.full(size=(B, Nh, 1), dtype=torch.bool, device=out.device, fill_value=v),
35
+ out,
36
+ ],
37
+ dim=-1,
38
+ )
39
+ return out
40
+
41
+
42
+ def _swiglu(h: torch.Tensor) -> torch.Tensor:
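+ # SwiGLU: split the last dimension in half and gate one half with SiLU applied to the other.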
43
+ h0, h1 = h.chunk(2, dim=-1)
44
+ return torch.nn.functional.silu(h0) * h1
45
+
46
+
47
+ class RotaryEmbedding(torch.nn.Module):
48
+ def __init__(
49
+ self, dim: int, max_position_embeddings: int = 2048, base: int = 10000, device: Optional[torch.device] = None
50
+ ) -> None:
51
+ super().__init__()
52
+
53
+ self.dim = dim
54
+ self.max_position_embeddings = max_position_embeddings
55
+ self.base = base
56
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
57
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
58
+
59
+ # Build here to make `torch.jit.trace` work.
60
+ self._set_cos_sin_cache(
61
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
62
+ )
63
+
64
+ def _set_cos_sin_cache(self, seq_len: int, device: Any, dtype: Any) -> None:
65
+ self.max_seq_len_cached = seq_len
66
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) # type: ignore
67
+
68
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
69
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
70
+ emb = torch.cat((freqs, freqs), dim=-1)
71
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
72
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
73
+
74
+ def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
75
+ # x: [bs, num_attention_heads, seq_len, head_size]
76
+ if seq_len > self.max_seq_len_cached:
77
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
78
+
79
+ return (
80
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), # type: ignore
81
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), # type: ignore
82
+ )
83
+
84
+
85
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
86
+ """Rotates half the hidden dims of the input."""
87
+ x1 = x[..., : x.shape[-1] // 2]
88
+ x2 = x[..., x.shape[-1] // 2 :]
89
+ return torch.cat((-x2, x1), dim=-1)
90
+
91
+
92
+ def _rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
93
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
94
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
95
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
96
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
97
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
98
+ x_embed = (x * cos) + (_rotate_half(x) * sin)
99
+ return x_embed
100
+
101
+
102
+ class LinearType(str, enum.Enum):
103
+ Normal = "normal"
104
+ Fp8 = "fp8"
105
+ Fp8Retain = "fp8-retain"
106
+
107
+
108
+ class PlamoConfig(PretrainedConfig): # type: ignore
109
+ model_type: str = "plamo"
110
+
111
+ def __init__(
112
+ self,
113
+ hidden_size: int = 4096,
114
+ num_hidden_layers: int = 32,
115
+ rms_norm_eps: float = 1e-6,
116
+ tie_word_embeddings: bool = True,
117
+ # Attention
118
+ num_attention_heads: int = 32,
119
+ num_key_value_heads: int = 4,
120
+ hidden_size_per_head: int = 128,
121
+ max_position_embeddings: int = 2048,
122
+ attention_window_size: int = 2048,
123
+ full_attention_idx: list[int] | None = None,
124
+ # Mamba
125
+ mamba_d_state: int = 64,
126
+ mamba_d_conv: int = 4,
127
+ mamba_num_heads: int = 64,
128
+ mamba_step: int = 2,
129
+ mamba_chunk_size: int = 256,
130
+ mamba_enabled: bool = True,
131
+ # MLP
132
+ intermediate_size: int = 13312,
133
+ # Tokenizer
134
+ vocab_size: int = 32000,
135
+ tokenizer_class: str = "PlamoTokenizer",
136
+ pad_token_id: Optional[int] = None,
137
+ bos_token_id: int = 1,
138
+ eos_token_id: int = 2,
139
+ # Multimodal
140
+ image_token_id: Optional[int] = None,
141
+ image_feature_size: Optional[int] = None,
142
+ image_proj_type: Literal["linear", "mlp"] = "linear",
143
+ # FP8
144
+ linear_type: LinearType = LinearType.Normal,
145
+ fp8_accum_dtype: Optional[str] = None,
146
+ # Evaluation
147
+ eval_attention_n_bit: Optional[int] = None,
148
+ eval_mlp_n_bit: Optional[int] = None,
149
+ use_cache: bool = True,
150
+ **kwargs: Any,
151
+ ) -> None:
152
+ # max_position_embeddings is often used to determine the max length during inference,
153
+ # but samba should have extrapolation abilities
154
+ self.max_position_embeddings = max(10 * 1024 * 1024, max_position_embeddings)
155
+ self.hidden_size = hidden_size
156
+ self.rms_norm_eps = rms_norm_eps
157
+
158
+ self.num_hidden_layers = num_hidden_layers
159
+ self.num_attention_heads = num_attention_heads
160
+ self.hidden_size_per_head = hidden_size_per_head
161
+ self.num_key_value_heads = num_key_value_heads
162
+ self.attention_window_size = attention_window_size
163
+ self.full_attention_idx = full_attention_idx if full_attention_idx is not None else []
164
+
165
+ self.mamba_d_state = mamba_d_state
166
+ self.mamba_d_conv = mamba_d_conv
167
+ self.mamba_num_heads = mamba_num_heads
168
+ self.mamba_step = mamba_step
169
+ self.mamba_chunk_size = mamba_chunk_size
170
+ self.mamba_enabled = mamba_enabled
171
+
172
+ self.intermediate_size = intermediate_size
173
+
174
+ self.vocab_size = vocab_size
175
+
176
+ self.image_token_id = image_token_id
177
+ self.image_feature_size = image_feature_size
178
+ self.image_proj_type = image_proj_type
179
+
180
+ self.linear_type = linear_type
181
+ self.fp8_accum_dtype = fp8_accum_dtype
182
+
183
+ self.eval_attention_n_bit = eval_attention_n_bit
184
+ self.eval_mlp_n_bit = eval_mlp_n_bit
185
+ self.use_cache = use_cache
186
+
187
+ # fields for vLLM
188
+ self.sliding_window = attention_window_size
189
+
190
+ super().__init__(
191
+ tokenizer_class=tokenizer_class,
192
+ pad_token_id=pad_token_id,
193
+ bos_token_id=bos_token_id,
194
+ eos_token_id=eos_token_id,
195
+ tie_word_embeddings=tie_word_embeddings,
196
+ **kwargs,
197
+ )
198
+
199
+
200
+ class PlamoAttentionCache(torch.nn.Module):
201
+ def __init__(self, key: torch.Tensor, value: torch.Tensor) -> None:
202
+ super().__init__()
203
+ B, nh, L, c = key.shape
204
+ assert len(value.shape) == 4
205
+ assert value.shape[0] == B
206
+ assert value.shape[2] == L
207
+ self.register_parameter("key", torch.nn.Parameter(key, requires_grad=False))
208
+ self.register_parameter("value", torch.nn.Parameter(value, requires_grad=False))
209
+
210
+
211
+ class PlamoMambaCache(torch.nn.Module):
212
+ def __init__(self, conv_state: torch.Tensor, ssm_state: torch.Tensor) -> None:
213
+ super().__init__()
214
+ # conv_state: [B, C, d_conv]
215
+ # ssm_state: [B, nhead, nchannel_per_head, d_state]
216
+ assert len(conv_state.shape) == 3
217
+ assert len(ssm_state.shape) == 4
218
+ assert conv_state.shape[0] == ssm_state.shape[0]
219
+ self.register_parameter("conv_state", torch.nn.Parameter(conv_state, requires_grad=False))
220
+ self.register_parameter("ssm_state", torch.nn.Parameter(ssm_state, requires_grad=False))
221
+
222
+
223
+ PlamoLayerCache = PlamoAttentionCache | PlamoMambaCache
224
+
225
+
226
+ class PlamoCache(torch.nn.Module):
227
+ """
228
+ Stores states of the model for fast decoding.
229
+ `transformers` uses `transformers.Cache` for this purpose, but the interface and variable names are
230
+ deeply dependent on the Transformers architecture (e.g., `key_states`), which makes it difficult to use with
231
+ other architectures (e.g., Mamba).
232
+ This class provides a similar interface to `transformers.Cache`, but is designed to also handle
233
+ the state of Mamba properly.
234
+ """
235
+
236
+ def __init__(self, config: PlamoConfig) -> None:
237
+ super().__init__()
238
+ self.config = config
239
+ self.cache = torch.nn.ModuleList([None for _ in range(config.num_hidden_layers)]) # type: ignore
240
+
241
+ def append_kv(self, key: torch.Tensor, value: torch.Tensor, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
242
+ c = self.cache[layer_idx]
243
+ assert isinstance(c, PlamoAttentionCache)
244
+
245
+ def _validate(cache: torch.Tensor, new_tensor: torch.Tensor) -> None:
246
+ assert len(cache.shape) == 4
247
+ assert len(new_tensor.shape) == 4
248
+ assert cache.shape[0] == new_tensor.shape[0]
249
+ assert cache.shape[1] == new_tensor.shape[1]
250
+ assert cache.shape[3] == new_tensor.shape[3]
251
+
252
+ _validate(c.key, key)
253
+ _validate(c.value, value)
254
+ assert key.shape[2] == value.shape[2]
255
+ return torch.cat([c.key, key], dim=2), torch.cat([c.value, value], dim=2)
256
+
257
+ def update_attention(
258
+ self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int
259
+ ) -> PlamoAttentionCache:
260
+ if self.cache[layer_idx] is None:
261
+ self.cache[layer_idx] = PlamoAttentionCache(key_states, value_states)
262
+ else:
263
+ full_attn = layer_idx in self.config.full_attention_idx
264
+ window_size = self.config.attention_window_size
265
+ c = self.cache[layer_idx]
266
+ assert isinstance(c, PlamoAttentionCache)
267
+ k, v = self.append_kv(key_states, value_states, layer_idx)
268
+ if full_attn:
269
+ c.key.data = k
270
+ c.value.data = v
271
+ else:
272
+ c.key.data = k[:, :, -window_size:, :]
273
+ c.value.data = v[:, :, -window_size:, :]
274
+ return self.cache[layer_idx] # type: ignore
275
+
276
+ def update_mamba(self, conv_state: torch.Tensor, ssm_state: torch.Tensor, layer_idx: int) -> PlamoMambaCache:
277
+ if self.cache[layer_idx] is None:
278
+ self.cache[layer_idx] = PlamoMambaCache(conv_state, ssm_state)
279
+ else:
280
+ c = self.cache[layer_idx]
281
+ assert isinstance(c, PlamoMambaCache)
282
+ assert c.conv_state.shape == conv_state.shape
283
+ assert c.ssm_state.shape == ssm_state.shape
284
+ c.conv_state.data = conv_state
285
+ c.ssm_state.data = ssm_state
286
+ return self.cache[layer_idx] # type: ignore
287
+
288
+ def __getitem__(self, layer_idx: int) -> PlamoLayerCache | None:
289
+ assert layer_idx < len(self.cache)
290
+ layer_cache = self.cache[layer_idx]
291
+ return layer_cache # type: ignore
292
+
293
+ def __len__(self) -> int:
294
+ return len(self.cache)
295
+
296
+ def get_seq_length(self, layer_idx: Optional[int] = None) -> int:
297
+ if layer_idx is not None:
298
+ c = self.cache[layer_idx]
299
+ assert isinstance(c, PlamoAttentionCache)
300
+ return c.key.shape[2] # type: ignore
301
+
302
+ sequence_length: int | None = None
303
+ for layer_cache in self.cache:
304
+ if isinstance(layer_cache, PlamoAttentionCache):
305
+ sequence_length = (
306
+ max(layer_cache.key.shape[2], sequence_length)
307
+ if sequence_length is not None
308
+ else layer_cache.key.shape[2]
309
+ )
310
+ assert sequence_length is not None
311
+ return sequence_length
312
+
313
+ def get_max_length(self) -> int | None:
314
+ return None
315
+
316
+ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
317
+ """Given the sequence length of the new inputs, returns the usable length of the cache."""
318
+ # Cache without size limit -> all cache is usable
319
+ # Cache with size limit -> if the cache length plus the length of the new inputs is larger than the maximum cache
320
+ # length, we will need to evict part of the cache (and thus not all cache is usable)
321
+ max_length = self.get_max_length()
322
+ previous_seq_length = self.get_seq_length(layer_idx)
323
+ if max_length is not None and previous_seq_length + new_seq_length > max_length:
324
+ return max_length - new_seq_length
325
+ return previous_seq_length
326
+
327
+ def reorder_cache(self, beam_idx: torch.Tensor) -> None:
328
+ def _mamba(cache: PlamoMambaCache) -> PlamoMambaCache:
329
+ return PlamoMambaCache(
330
+ conv_state=cache.conv_state.index_select(0, beam_idx),
331
+ ssm_state=cache.ssm_state.index_select(0, beam_idx),
332
+ )
333
+
334
+ def _attention(cache: PlamoAttentionCache) -> PlamoAttentionCache:
335
+ return PlamoAttentionCache(
336
+ key=cache.key.index_select(0, beam_idx),
337
+ value=cache.value.index_select(0, beam_idx),
338
+ )
339
+
340
+ for i in range(len(self.cache)):
341
+ if self.cache[i] is None:
342
+ continue
343
+ layer_cache = self.cache[i]
344
+ if isinstance(layer_cache, PlamoMambaCache):
345
+ self.cache[i] = _mamba(layer_cache)
346
+ else:
347
+ assert isinstance(layer_cache, PlamoAttentionCache)
348
+ self.cache[i] = _attention(layer_cache)
349
+
350
+ @property
351
+ def seen_tokens(self) -> int | None:
352
+ return None
353
+
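For intuition, `update_attention` above keeps a sliding-window KV cache for layers not listed in `full_attention_idx`: new keys/values are appended along the sequence axis and then trimmed to the last `attention_window_size` positions. A standalone sketch with made-up sizes (not the actual class API):

```python
import torch

# Hypothetical sizes, for illustration only.
batch, num_heads, head_dim, window_size = 1, 2, 4, 8

key_cache = torch.zeros(batch, num_heads, 0, head_dim)
for _ in range(12):
    new_key = torch.randn(batch, num_heads, 1, head_dim)
    # append along the sequence axis (dim=2), as append_kv does ...
    key_cache = torch.cat([key_cache, new_key], dim=2)
    # ... then keep only the last `window_size` positions, as update_attention does
    key_cache = key_cache[:, :, -window_size:, :]

print(key_cache.shape)  # torch.Size([1, 2, 8, 4])
```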
354
+
355
+ class DecoderInput(NamedTuple):
356
+ hidden_states: torch.Tensor
357
+ attention_mask: Optional[torch.Tensor] = None
358
+ past_states: Optional[PlamoCache] = None
359
+ output_hidden_states: Optional[bool] = False
360
+ output_attentions: Optional[bool] = False
361
+ gradient_checkpointing: bool = False
362
+ input_ids: Optional[torch.Tensor] = None
363
+
364
+
365
+ class DecoderOutput(NamedTuple):
366
+ hidden_states: torch.Tensor
367
+ all_hidden_states: Optional[Tuple[torch.Tensor, ...]]
368
+ all_self_attns: Optional[Tuple[torch.Tensor, ...]]
369
+
370
+
371
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
372
+ def _make_causal_mask(
373
+ input_ids_shape: Tuple[int, int], dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
374
+ ) -> torch.Tensor:
375
+ """
376
+ Make causal mask used for bi-directional self-attention.
377
+ """
378
+ bsz, tgt_len = input_ids_shape
379
+ mask = torch.full((tgt_len, tgt_len), float("-inf"), device=device)
380
+ mask_cond = torch.arange(mask.size(-1), device=device)
381
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
382
+ mask = mask.to(dtype)
383
+
384
+ if past_key_values_length > 0:
385
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
386
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
387
+
388
+
389
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
390
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None) -> torch.Tensor:
391
+ """
392
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
393
+ """
394
+ bsz, src_len = mask.size()
395
+ tgt_len = tgt_len if tgt_len is not None else src_len
396
+
397
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
398
+
399
+ inverted_mask = 1.0 - expanded_mask
400
+
401
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), float("-inf")) # type: ignore
402
+
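For intuition, the two helpers above use an additive-mask convention: allowed positions carry 0 and blocked positions carry -inf, so a causal mask and a padding mask can simply be added together (as `PlamoModel._prepare_decoder_attention_mask` does later in this file). A standalone toy example with arbitrary sizes:

```python
import torch

tgt_len, past = 3, 1
# Causal part: -inf strictly above the diagonal, 0 elsewhere, plus zero columns for past tokens.
causal = torch.full((tgt_len, tgt_len), float("-inf")).triu(diagonal=1)
causal = torch.cat([torch.zeros(tgt_len, past), causal], dim=-1)

# Padding part: a [bsz, src_len] mask where 1 = keep and 0 = pad, expanded to 4D.
pad = torch.tensor([[1.0, 1.0, 1.0, 0.0]])
expanded = (1.0 - pad)[:, None, None, :].expand(1, 1, tgt_len, tgt_len + past)
expanded = expanded.masked_fill(expanded.bool(), float("-inf"))

combined = causal[None, None] + expanded  # [bsz, 1, tgt_len, src_len]
print(combined)
```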
403
+
404
+ def _rms_norm(
405
+ hidden_states: torch.Tensor, weight: Optional[torch.Tensor], eps: float, offset: float = 1.0
406
+ ) -> torch.Tensor:
407
+ input_dtype = hidden_states.dtype
408
+ hidden_states = hidden_states.to(torch.float32)
409
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
410
+ hidden_states = hidden_states * torch.rsqrt(variance + eps)
411
+ hidden_states = hidden_states.to(input_dtype)
412
+ if weight is not None:
413
+ hidden_states = (offset + weight) * hidden_states
414
+ return hidden_states
415
+
416
+
417
+ class RMSNorm(nn.Module):
418
+ def __init__(
419
+ self,
420
+ hidden_size: int,
421
+ eps: float = 1e-6,
422
+ offset: float = 1.0,
423
+ device: Optional[Union[torch.device, str]] = None,
424
+ ) -> None:
425
+ super().__init__()
426
+ self.weight = nn.Parameter(torch.zeros(hidden_size, device=device))
427
+ self.variance_epsilon = eps
428
+ self.offset = offset
429
+
430
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
431
+ return _rms_norm(hidden_states, self.weight, self.variance_epsilon, offset=self.offset)
432
+
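Since `weight` is initialized to zeros and the effective scale is `offset + weight`, the module starts out as a plain RMSNorm with unit gain. A quick standalone check of the same formula (shapes and `eps` are arbitrary):

```python
import torch

x = torch.randn(2, 5, 16)
eps = 1e-6

# Hand-written RMSNorm: x / sqrt(mean(x^2) + eps), computed in float32 as in _rms_norm.
ref = (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)).to(x.dtype)

weight = torch.zeros(16)      # as in RMSNorm.__init__
out = (1.0 + weight) * ref    # offset=1.0, so this is exactly `ref` at initialization
assert torch.allclose(out, ref)
```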
433
+
434
+ def get_initial_dt_bias(num_heads: int) -> torch.Tensor:
435
+ dt_min = 0.001
436
+ dt_max = 0.1
437
+ dt = torch.exp(torch.rand(num_heads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
438
+ dt = torch.clamp(dt, 1e-4)
439
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
440
+ return inv_dt
441
+
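The expression `dt + log(-expm1(-dt))` is the inverse of softplus, so the bias is stored in a space where applying softplus inside the SSM kernel recovers a time step in roughly `[dt_min, dt_max]`. A small standalone check:

```python
import torch
import torch.nn.functional as F

dt = torch.tensor([0.001, 0.01, 0.1])
inv_dt = dt + torch.log(-torch.expm1(-dt))   # softplus^{-1}(dt)
assert torch.allclose(F.softplus(inv_dt), dt, atol=1e-5)
```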
442
+
443
+ def get_initial_A(num_heads: int) -> torch.Tensor:
444
+ A = torch.arange(1, num_heads + 1, dtype=torch.float32)
445
+ return torch.log(A)
446
+
447
+
448
+ def _bf16_supported_in_triton() -> bool:
449
+ # Newer torch (2.2.0 and later?) supports bfloat16 even on Volta GPUs,
450
+ # but Triton cannot compile bf16 kernels for Volta.
451
+ major, _ = torch.cuda.get_device_capability()
452
+ return major >= 8
453
+
454
+
455
+ def _get_trition_dtype(dtype: torch.dtype) -> torch.dtype:
456
+ if dtype != torch.bfloat16:
457
+ return dtype
458
+ if _bf16_supported_in_triton():
459
+ return dtype
460
+ return torch.float32
461
+
462
+
463
+ def ssd_update_state(
464
+ ssm_state: torch.Tensor,
465
+ x: torch.Tensor,
466
+ dt: torch.Tensor,
467
+ A: torch.Tensor,
468
+ B: torch.Tensor,
469
+ C: torch.Tensor,
470
+ D: torch.Tensor,
471
+ z: torch.Tensor,
472
+ dt_bias: torch.Tensor,
473
+ dt_softplus: bool,
474
+ ) -> torch.Tensor:
475
+ assert ssm_state.dtype == torch.float32
476
+ if dt.is_cuda:
477
+ dtype = _get_trition_dtype(x.dtype)
478
+ else:
479
+ dtype = x.dtype
480
+ if dt.is_cuda:
481
+ f = mamba_ssm.ops.triton.selective_state_update.selective_state_update
482
+ else:
483
+ f = mamba_ssm.ops.triton.selective_state_update.selective_state_update_ref
484
+
485
+ hidden_size_per_head = x.shape[-1]
486
+ d_state = B.shape[-1]
487
+ A = A[:, None, None].expand(-1, hidden_size_per_head, d_state).float()
488
+ dt = dt[..., None].expand(-1, -1, hidden_size_per_head)
489
+ dt_bias = dt_bias[:, None].expand(-1, hidden_size_per_head)
490
+ D = D[:, None].expand(-1, hidden_size_per_head)
491
+ assert ssm_state.dtype == torch.float32
492
+ out = f(
493
+ ssm_state,
494
+ x.to(dtype),
495
+ dt.to(dtype),
496
+ A.float(),
497
+ B.to(dtype),
498
+ C.to(dtype),
499
+ D.float(),
500
+ z.to(dtype),
501
+ dt_bias.float(),
502
+ dt_softplus=dt_softplus,
503
+ )
504
+ return out[:, None] # type: ignore
505
+
506
+
507
+ def _ssd_chunk_scan_combined_naive(
508
+ x: torch.Tensor,
509
+ dt: torch.Tensor,
510
+ A: torch.Tensor,
511
+ B: torch.Tensor,
512
+ C: torch.Tensor,
513
+ D: torch.Tensor,
514
+ z: torch.Tensor,
515
+ dt_bias: torch.Tensor,
516
+ dt_softplus: bool,
517
+ seq_idx: torch.Tensor | None,
518
+ ssm_state: torch.Tensor,
519
+ ) -> tuple[torch.Tensor, torch.Tensor]:
520
+ assert ssm_state.dtype == torch.float32
521
+ length = x.shape[1]
522
+ ys = []
523
+ for i in range(length):
524
+ if i != 0 and seq_idx is not None:
525
+ ssm_state = torch.where(
526
+ (seq_idx[:, i - 1] != seq_idx[:, i])[:, None, None, None],
527
+ torch.zeros_like(ssm_state),
528
+ ssm_state,
529
+ )
530
+ y = ssd_update_state(
531
+ ssm_state,
532
+ x[:, i],
533
+ dt[:, i],
534
+ A,
535
+ B[:, i],
536
+ C[:, i],
537
+ D,
538
+ z=z[:, i],
539
+ dt_bias=dt_bias,
540
+ dt_softplus=dt_softplus,
541
+ )
542
+ ys.append(y)
543
+ return torch.cat(ys, dim=1), ssm_state
544
+
545
+
546
+ def ssd_chunk_scan_combined(
547
+ x: torch.Tensor,
548
+ dt: torch.Tensor,
549
+ A: torch.Tensor,
550
+ B: torch.Tensor,
551
+ C: torch.Tensor,
552
+ chunk_size: int,
553
+ D: torch.Tensor,
554
+ z: torch.Tensor,
555
+ dt_bias: torch.Tensor,
556
+ dt_softplus: bool,
557
+ return_final_states: bool,
558
+ seq_idx: torch.Tensor | None,
559
+ ssm_state: torch.Tensor | None,
560
+ ) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
561
+ if seq_idx is not None:
562
+ assert seq_idx.dtype == torch.int32
563
+ assert ssm_state is None
564
+ assert not return_final_states
565
+ if ssm_state is not None:
566
+ assert ssm_state.dtype == torch.float32
567
+ assert seq_idx is None
568
+
569
+ length = x.shape[1]
570
+
571
+ """
572
+ The state is updated as follows:
573
+ ```
574
+ dt = softplus(dt)
575
+ dA = exp(dt * A)
576
+ state_next = state * dA + dB * x
577
+ ```
578
+
579
+ To avoid updating state, we set dt to -inf and x to 0
580
+ because `softplus(-inf) = 0` and `exp(0) = 1`
581
+ """
582
+ if dt.is_cuda:
583
+ pad = (chunk_size - length % chunk_size) % chunk_size
584
+ x = torch.nn.functional.pad(x, pad=[0, 0, 0, 0, pad, 0], value=0.0)
585
+ dt = torch.nn.functional.pad(dt, pad=[0, 0, pad, 0], value=float("-inf"))
586
+ B = torch.nn.functional.pad(B, pad=[0, 0, 0, 0, pad, 0], value=0.0)
587
+ C = torch.nn.functional.pad(C, pad=[0, 0, 0, 0, pad, 0], value=0.0)
588
+ z = torch.nn.functional.pad(z, pad=[0, 0, 0, 0, pad, 0], value=0.0)
589
+ if seq_idx is not None:
590
+ seq_idx = torch.nn.functional.pad(seq_idx, pad=[pad, 0], value=0)
591
+
592
+ length = x.shape[1]
593
+ assert length % chunk_size == 0, (length, chunk_size)
594
+
595
+ dtype = _get_trition_dtype(x.dtype)
596
+ out = mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined( # type: ignore
597
+ x.to(dtype),
598
+ dt.to(dtype),
599
+ A.float(),
600
+ B.to(dtype),
601
+ C.to(dtype),
602
+ chunk_size,
603
+ D=D.float(),
604
+ z=z.to(dtype),
605
+ initial_states=ssm_state,
606
+ dt_bias=dt_bias.float(),
607
+ dt_softplus=dt_softplus,
608
+ seq_idx=seq_idx,
609
+ return_final_states=return_final_states,
610
+ )
611
+ if return_final_states:
612
+ return out[0][:, pad:], out[1]
613
+ else:
614
+ assert isinstance(out, torch.Tensor)
615
+ return out[:, pad:]
616
+ else:
617
+ if ssm_state is None:
618
+ bsize, _, num_heads, channel = x.shape
619
+ state = B.shape[-1]
620
+ ssm_state = torch.zeros(bsize, num_heads, channel, state, dtype=torch.float32, device=x.device)
621
+ tmp = _ssd_chunk_scan_combined_naive(
622
+ x, dt, A, B, C, D, z=z, dt_bias=dt_bias, dt_softplus=dt_softplus, seq_idx=seq_idx, ssm_state=ssm_state
623
+ )
624
+ if return_final_states:
625
+ return tmp
626
+ else:
627
+ return tmp[0]
628
+
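The padding trick described in the docstring above rests on two identities: `softplus(-inf) = 0` and `exp(0) = 1`, so a padded position multiplies the state by 1 and adds 0. A toy standalone sketch of a single recurrence step (shapes and values are made up):

```python
import torch
import torch.nn.functional as F

state = torch.randn(4)              # toy SSM state
A = -torch.rand(4)                  # negative, like -exp(A_log)
x = torch.zeros(4)                  # padded inputs are zeroed out
dt = torch.tensor(float("-inf"))    # padded dt is set to -inf

dt = F.softplus(dt)                 # -> 0
dA = torch.exp(dt * A)              # -> 1
new_state = state * dA + dt * x     # the dB * x term vanishes because dt == 0 and x == 0
assert torch.equal(new_state, state)
```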
629
+
630
+ def _causal_conv1d(
631
+ conv_state: torch.Tensor | None, weight: torch.Tensor, x: torch.Tensor, seq_idx: torch.Tensor | None
632
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
633
+ dtype = x.dtype
634
+ if conv_state is not None:
635
+ dtype = conv_state.dtype
636
+ assert seq_idx is None
637
+ if seq_idx is not None:
638
+ assert seq_idx.dtype == torch.int32
639
+ assert conv_state is None
640
+ weight = weight.to(dtype)
641
+ x = x.to(dtype)
642
+
643
+ return_final_states = conv_state is not None
644
+ if weight.is_cuda:
645
+ if x.stride(1) != 1:
646
+ # to channel-last format
647
+ x = x.transpose(-1, -2).contiguous().transpose(-1, -2)
648
+ if conv_state is not None:
649
+ if conv_state.stride(1) != 1:
650
+ # to channel-last format
651
+ conv_state = conv_state.transpose(-1, -2).contiguous().transpose(-1, -2)
652
+ tmp = causal_conv1d.causal_conv1d_fn(
653
+ x=x,
654
+ weight=weight[:, 0, :],
655
+ initial_states=conv_state,
656
+ return_final_states=conv_state is not None,
657
+ activation="silu",
658
+ seq_idx=seq_idx,
659
+ )
660
+ if conv_state is not None:
661
+ x, conv_state = tmp
662
+ else:
663
+ x = tmp
664
+ else:
665
+ if conv_state is None:
666
+ bsize = x.shape[0]
667
+ dim = weight.shape[0]
668
+ d_conv = weight.shape[-1]
669
+ conv_state = torch.zeros(bsize, dim, d_conv - 1, dtype=x.dtype, device=x.device)
670
+ length = x.shape[-1]
671
+ out = torch.zeros_like(x)
672
+ for i in range(length):
673
+ if i != 0 and seq_idx is not None:
674
+ conv_state = torch.where(
675
+ (seq_idx[:, i - 1] != seq_idx[:, i])[:, None, None],
676
+ torch.zeros_like(conv_state),
677
+ conv_state,
678
+ )
679
+ out[:, :, i : i + 1], conv_state = _causal_conv1d_update(conv_state, weight, x[:, :, i : i + 1])
680
+ x = out
681
+ if return_final_states:
682
+ return x, conv_state
683
+ else:
684
+ return x, None
685
+
686
+
687
+ def _causal_conv1d_update(
688
+ conv_state: torch.Tensor, weight: torch.Tensor, xBC: torch.Tensor
689
+ ) -> tuple[torch.Tensor, torch.Tensor]:
690
+ dtype = conv_state.dtype
691
+ xBC = xBC.to(dtype)
692
+ weight = weight.to(dtype)
693
+ if conv_state.is_cuda:
694
+ x = causal_conv1d.causal_conv1d_update(
695
+ x=xBC,
696
+ conv_state=conv_state,
697
+ weight=weight[:, 0, :],
698
+ activation="silu",
699
+ )
700
+ return x, conv_state
701
+ else:
702
+ x = causal_conv1d.causal_conv1d_update_ref(
703
+ x=xBC,
704
+ conv_state=conv_state,
705
+ weight=weight[:, 0, :],
706
+ activation="silu",
707
+ )
708
+ return x, conv_state
709
+
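Conceptually, a causal depthwise conv update keeps the last `d_conv - 1` samples per channel, appends the new sample, and takes a per-channel dot product with the kernel followed by SiLU. This is only a rough pure-PyTorch sketch of what `causal_conv1d_update` computes, not the library kernel itself:

```python
import torch
import torch.nn.functional as F

dim, d_conv = 3, 4
conv_state = torch.zeros(1, dim, d_conv - 1)   # last d_conv - 1 inputs per channel
weight = torch.randn(dim, d_conv)              # depthwise kernel, one row per channel

def sketch_update(conv_state, weight, x_t):
    window = torch.cat([conv_state, x_t], dim=-1)         # (1, dim, d_conv)
    y_t = (window * weight[None]).sum(-1, keepdim=True)   # depthwise causal conv
    return F.silu(y_t), window[:, :, 1:]                  # drop the oldest sample

x_t = torch.randn(1, dim, 1)
y_t, conv_state = sketch_update(conv_state, weight, x_t)
print(y_t.shape, conv_state.shape)  # torch.Size([1, 3, 1]) torch.Size([1, 3, 3])
```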
710
+
711
+ class Mamba(torch.nn.Module):
712
+ def __init__(self, config: PlamoConfig, layer_idx: int) -> None:
713
+ super().__init__()
714
+ self.config = config
715
+ self.layer_idx = layer_idx
716
+ self.hidden_size = config.hidden_size
717
+ self.d_state = config.mamba_d_state
718
+ self.d_conv = config.mamba_d_conv
719
+ self.chunk_size = config.mamba_chunk_size
720
+ self.num_heads = config.mamba_num_heads
721
+ # TODO add mamba_hidden_size_per_head config (?)
722
+ self.hidden_size_per_head = config.hidden_size_per_head
723
+
724
+ self.intermediate_size = self.num_heads * self.hidden_size_per_head
725
+
726
+ self.in_proj = torch.nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
727
+ self.conv1d = torch.nn.Conv1d(
728
+ in_channels=self.intermediate_size,
729
+ out_channels=self.intermediate_size,
730
+ bias=False, # TODO the original implementation uses bias
731
+ kernel_size=self.d_conv,
732
+ groups=self.intermediate_size,
733
+ padding=0,
734
+ )
735
+ self.dt_dim = max(64, self.hidden_size // 16)
736
+ # Notes:
737
+ # Mamba2 removes this linear projection for simplicity (Figure 6 in the paper),
738
+ # but it may degrade the model's ability to extrapolate to longer context lengths.
739
+ self.bcdt_proj = torch.nn.Linear(
740
+ self.intermediate_size,
741
+ self.dt_dim + 2 * self.d_state,
742
+ bias=False,
743
+ )
744
+ self.dt_proj = torch.nn.Linear(self.dt_dim, self.num_heads, bias=False)
745
+
746
+ self.dt_bias = torch.nn.Parameter(get_initial_dt_bias(self.num_heads))
747
+ self.A_log = torch.nn.Parameter(get_initial_A(self.num_heads))
748
+ self.D = torch.nn.Parameter(torch.ones(self.num_heads))
749
+
750
+ # TODO norm weight before gating like Mamba2
751
+ self.dt_norm_weight = torch.nn.Parameter(torch.ones(self.dt_dim))
752
+ self.B_norm_weight = torch.nn.Parameter(torch.ones(self.d_state))
753
+ self.C_norm_weight = torch.nn.Parameter(torch.ones(self.d_state))
754
+
755
+ self.out_proj = torch.nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
756
+
757
+ def _no_weight_decay_param_names(self) -> set[str]:
758
+ return set(["D", "dt_bias", "A_log"])
759
+
760
+ def forward(
761
+ self,
762
+ hidden_states: torch.Tensor,
763
+ attention_mask: Optional[torch.Tensor] = None,
764
+ past_states: Optional[PlamoCache] = None,
765
+ ) -> Tuple[torch.Tensor, Optional[PlamoCache]]:
766
+ bsize, length, _ = hidden_states.shape
767
+ is_update = length == 1 and past_states is not None
768
+
769
+ bool_mask: torch.Tensor | None = None
770
+ seq_idx: torch.Tensor | None = None
771
+ if attention_mask is not None:
772
+ if len(attention_mask.shape) == 2:
773
+ attention_mask = attention_mask[None, None].expand(bsize, 1, -1, -1)
774
+ assert len(attention_mask.shape) == 4
775
+
776
+ if past_states is None:
777
+ # TODO: support seq_idx with cache
778
+ bool_mask_4d = attention_mask == 0
779
+ is_first_token = _is_first_token(bool_mask_4d)[:, 0, :]
780
+ seq_idx = torch.cumsum(is_first_token, dim=-1) - 1
781
+ seq_idx = seq_idx.to(torch.int32)
782
+
783
+ # The `generate` function creates an attention mask that contains past tokens,
784
+ # but the Mamba layer does not use them.
785
+ attention_mask = attention_mask[:, 0, -length:, -length:]
786
+ bool_mask = torch.diagonal(attention_mask, dim1=-2, dim2=-1) == 0
787
+
788
+ conv_state: torch.Tensor | None
789
+ ssm_state: torch.Tensor | None
790
+ if past_states is None:
791
+ conv_state = None
792
+ ssm_state = None
793
+ elif past_states[self.layer_idx] is None:
794
+ conv_state = torch.zeros(
795
+ bsize, self.intermediate_size, self.d_conv - 1, dtype=hidden_states.dtype, device=hidden_states.device
796
+ )
797
+ ssm_state = torch.zeros(
798
+ bsize,
799
+ self.num_heads,
800
+ self.hidden_size_per_head,
801
+ self.d_state,
802
+ dtype=torch.float32,
803
+ device=hidden_states.device,
804
+ )
805
+ else:
806
+ c = past_states[self.layer_idx]
807
+ assert isinstance(c, PlamoMambaCache)
808
+ conv_state = c.conv_state
809
+ ssm_state = c.ssm_state
810
+
811
+ zx = self.in_proj(hidden_states)
812
+ zx = zx.reshape(bsize, length, self.num_heads, -1)
813
+ # z: (bsize, length, num_heads, hidden_size_per_head)
814
+ # x: (bsize, length, num_heads, hidden_size_per_head)
815
+ z, x = torch.split(zx, [self.hidden_size_per_head, self.hidden_size_per_head], dim=-1)
816
+
817
+ # conv
818
+ x = x.reshape(bsize, length, -1).transpose(1, 2) # (bsize, intermediate_size, length)
819
+ if bool_mask is not None:
820
+ x = torch.where(bool_mask[:, None, :], x, 0.0)
821
+ if is_update:
822
+ assert conv_state is not None
823
+ x, conv_state = _causal_conv1d_update(conv_state, self.conv1d.weight, x)
824
+ else:
825
+ x, conv_state = _causal_conv1d(conv_state, self.conv1d.weight, x, seq_idx=seq_idx)
826
+ x = x.to(dtype=hidden_states.dtype)
827
+ x = x.transpose(1, 2) # (bsize, length, intermediate_size)
828
+ x = x.reshape(bsize, length, -1)
829
+ # x: (bsize, length, num_heads, hidden_size_per_head)
830
+ # B: (bsize, length, 1, d_state)
831
+ # C: (bsize, length, 1, d_state)
832
+ # dt: (bsize, length, dt_dim)
833
+ BCdt = self.bcdt_proj(x)
834
+ x = x.reshape(bsize, length, self.num_heads, -1)
835
+ B, C, dt = torch.split(BCdt, [self.d_state, self.d_state, self.dt_dim], dim=-1)
836
+ B = B[:, :, None, :]
837
+ C = C[:, :, None, :]
838
+
839
+ A = -torch.exp(self.A_log.float()) # (num_heads,)
840
+ dt = _rms_norm(dt, None, self.config.rms_norm_eps) * self.dt_norm_weight[None, None, :]
841
+ B = _rms_norm(B, None, self.config.rms_norm_eps) * self.B_norm_weight[None, None, None, :]
842
+ C = _rms_norm(C, None, self.config.rms_norm_eps) * self.C_norm_weight[None, None, None, :]
843
+
844
+ # (bsize, length, num_heads, 1)
845
+ dt = self.dt_proj(dt)[..., None]
846
+
847
+ # TODO it may not be required
848
+ B = B.expand(-1, -1, self.num_heads, -1)
849
+ C = C.expand(-1, -1, self.num_heads, -1)
850
+
851
+ if bool_mask is not None:
852
+ """
853
+ The state is updated as follows:
854
+ ```
855
+ dt = softplus(dt)
856
+ dA = exp(dt * A)
857
+ state_next = state * dA + dB * x
858
+ ```
859
+
860
+ To avoid updating state, we set dt to -inf and x to 0
861
+ because `softplus(-inf) = 0` and `exp(0) = 1`
862
+ """
863
+ dt = torch.where(bool_mask[:, :, None, None], dt, float("-inf"))
864
+ x = torch.where(bool_mask[:, :, None, None], x, 0.0)
865
+
866
+ # ssm
867
+ if is_update:
868
+ assert ssm_state is not None
869
+ out = ssd_update_state(
870
+ ssm_state,
871
+ x[:, 0],
872
+ dt[:, 0].reshape(bsize, -1),
873
+ A,
874
+ B[:, 0],
875
+ C[:, 0],
876
+ D=self.D,
877
+ z=z[:, 0],
878
+ dt_bias=self.dt_bias,
879
+ dt_softplus=True,
880
+ )
881
+ else:
882
+ tmp = ssd_chunk_scan_combined(
883
+ x,
884
+ dt.reshape(bsize, length, -1),
885
+ A,
886
+ B,
887
+ C,
888
+ self.chunk_size,
889
+ D=self.D,
890
+ z=z,
891
+ dt_bias=self.dt_bias,
892
+ dt_softplus=True,
893
+ return_final_states=past_states is not None,
894
+ seq_idx=seq_idx,
895
+ ssm_state=ssm_state,
896
+ )
897
+ if past_states is not None:
898
+ out, ssm_state = tmp
899
+ else:
900
+ assert isinstance(tmp, torch.Tensor)
901
+ out = tmp
902
+
903
+ y = self.out_proj(out.reshape(bsize, length, -1))
904
+
905
+ if past_states is not None:
906
+ assert ssm_state is not None
907
+ assert conv_state is not None
908
+ past_states.update_mamba(conv_state, ssm_state, self.layer_idx)
909
+
910
+ return y, past_states
911
+
912
+
913
+ def swa_mask(q_len: int, kv_len: int, device: torch.device, window_size: int) -> torch.Tensor:
914
+ max_len = max(q_len, kv_len)
915
+ mask = (
916
+ torch.ones(max_len, max_len, dtype=torch.bool, device=device)
917
+ .triu(diagonal=-window_size)
918
+ .tril(diagonal=window_size)
919
+ )
920
+ return mask[-q_len:, -kv_len:]
921
+
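`swa_mask` builds a symmetric band of width `window_size` around the diagonal (True = may attend); causality itself is enforced by the additive attention mask it is combined with afterwards. A small standalone example:

```python
import torch

window_size = 2
mask = (
    torch.ones(5, 5, dtype=torch.bool)
    .triu(diagonal=-window_size)
    .tril(diagonal=window_size)
)
print(mask.int())
# 1 1 1 0 0
# 1 1 1 1 0
# 1 1 1 1 1
# 0 1 1 1 1
# 0 0 1 1 1
```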
922
+
923
+ class Attention(torch.nn.Module):
924
+ def __init__(self, config: PlamoConfig, layer_idx: int) -> None:
925
+ super().__init__()
926
+ self.config = config
927
+ self.layer_idx = layer_idx
928
+ self.hidden_size = config.hidden_size
929
+ head_dim = config.hidden_size_per_head
930
+ self.max_position_embeddings = config.max_position_embeddings
931
+
932
+ self.q_num_heads = config.num_attention_heads
933
+ self.qk_dim = self.v_dim = head_dim
934
+ self.k_num_heads = self.v_num_heads = config.num_key_value_heads
935
+ assert self.q_num_heads % self.k_num_heads == 0
936
+ self.n_group = self.q_num_heads // self.k_num_heads
937
+
938
+ self.q_proj_dim = self.q_num_heads * self.qk_dim
939
+ self.k_proj_dim = self.k_num_heads * self.qk_dim
940
+ self.v_proj_dim = self.k_num_heads * self.v_dim
941
+ self.qkv_proj = nn.Linear(self.hidden_size, self.q_proj_dim + self.k_proj_dim + self.v_proj_dim, bias=False)
942
+ self.o_proj = nn.Linear(self.q_num_heads * self.v_dim, self.hidden_size, bias=False)
943
+
944
+ self.q_weight = torch.nn.Parameter(torch.ones((self.q_num_heads, self.qk_dim)))
945
+ self.k_weight = torch.nn.Parameter(torch.ones((self.k_num_heads, self.qk_dim)))
946
+
947
+ self.rotary_emb = RotaryEmbedding(self.qk_dim, max_position_embeddings=self.config.attention_window_size)
948
+
949
+ def forward(
950
+ self,
951
+ hidden_states: torch.Tensor,
952
+ attention_mask: Optional[torch.Tensor] = None,
953
+ past_states: Optional[PlamoCache] = None,
954
+ output_attentions: bool = False,
955
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[PlamoCache]]:
956
+ bsz, q_len, _ = hidden_states.size()
957
+
958
+ qkv = self.qkv_proj(hidden_states)
959
+ query_states, key_states, value_states = torch.split(
960
+ qkv, [self.q_proj_dim, self.k_proj_dim, self.v_proj_dim], dim=-1
961
+ )
962
+ query_states = query_states.view(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(1, 2)
963
+ key_states = key_states.view(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(1, 2)
964
+ value_states = value_states.view(bsz, q_len, self.v_num_heads, self.v_dim).transpose(1, 2)
965
+
966
+ attn_dtype = query_states.dtype
967
+
968
+ query_states = _rms_norm(query_states, None, 1e-6) * self.q_weight[None, :, None]
969
+ key_states = _rms_norm(key_states, None, 1e-6) * self.k_weight[None, :, None]
970
+
971
+ if past_states is not None and past_states[self.layer_idx] is None:
972
+ bsz, nhead_k, _, c_k = key_states.shape
973
+ _, nhead_v, _, c_v = value_states.shape
974
+ past_states.update_attention(
975
+ torch.zeros((bsz, nhead_k, 0, c_k), dtype=key_states.dtype, device=key_states.device),
976
+ torch.zeros((bsz, nhead_v, 0, c_v), dtype=value_states.dtype, device=value_states.device),
977
+ self.layer_idx,
978
+ )
979
+
980
+ if past_states is not None:
981
+ # reuse k, v, self_attention
982
+ key_states_new = key_states
983
+ value_states_new = value_states
984
+ key_states, value_states = past_states.append_kv(key_states, value_states, self.layer_idx) # type: ignore
985
+ past_states.update_attention(key_states_new, value_states_new, self.layer_idx)
986
+
987
+ kv_seq_len = key_states.shape[-2]
988
+ device = hidden_states.device
989
+ position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=device)[None]
990
+ q_position_ids = position_ids[:, -query_states.shape[2] :]
991
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
992
+ query_states = _rotary_pos_emb(query_states, cos, sin, q_position_ids)
993
+ key_states = _rotary_pos_emb(key_states, cos, sin, position_ids)
994
+ # [bsz, nh, t, hd]
995
+
996
+ def _expand_kv(t: torch.Tensor, repeat: int, target: int) -> torch.Tensor:
997
+ t = torch.repeat_interleave(t, repeat, dim=1)
998
+ return t[:, :target]
999
+
1000
+ # expand shared kv
1001
+ assert self.k_num_heads == self.v_num_heads
1002
+ key_states = _expand_kv(key_states, self.n_group, self.q_num_heads)
1003
+ value_states = _expand_kv(value_states, self.n_group, self.q_num_heads)
1004
+
1005
+ full_attn = self.layer_idx in self.config.full_attention_idx
1006
+
1007
+ query_states = query_states.to(attn_dtype)
1008
+ key_states = key_states.to(attn_dtype)
1009
+ value_states = value_states.to(attn_dtype)
1010
+ if attention_mask is not None and attention_mask.dtype != torch.bool:
1011
+ attention_mask = attention_mask.to(attn_dtype)
1012
+ if attention_mask is None:
1013
+ if not full_attn:
1014
+ assert key_states.shape[2] <= self.config.attention_window_size + 1
1015
+ attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=True)
1016
+ else:
1017
+ if attention_mask.dtype == torch.bool:
1018
+ attention_mask = torch.where(attention_mask, torch.tensor(0.0, dtype=torch.float), float("-inf"))
1019
+ if len(attention_mask.shape) == 2:
1020
+ attention_mask = attention_mask[None, None]
1021
+ assert len(attention_mask.shape) == 4
1022
+
1023
+ if not full_attn:
1024
+ m_swa = swa_mask(
1025
+ query_states.shape[2], key_states.shape[2], query_states.device, self.config.attention_window_size
1026
+ )
1027
+ # The `generate` function creates an attention mask that does not take the sliding window into account
1028
+ m_swa = m_swa[None, None]
1029
+ attention_mask = attention_mask[:, :, -query_states.shape[2] :, -key_states.shape[2] :]
1030
+ attention_mask = torch.where(m_swa, attention_mask, float("-inf"))
1031
+
1032
+ # Like AttentionMaskConverter._unmask_unattended in huggingface transformers,
1033
+ # we need to attend to all tokens in fully masked rows for `scaled_dot_product_attention`
1034
+ bool_mask = torch.logical_not(torch.isneginf(attention_mask))
1035
+ valid_tokens = torch.sum(bool_mask, dim=-1).bool() # (..., q_len)
1036
+ attention_mask = torch.where(valid_tokens[..., None], attention_mask, float(0.0))
1037
+ attn_output = F.scaled_dot_product_attention(
1038
+ query_states, key_states, value_states, attn_mask=attention_mask
1039
+ )
1040
+
1041
+ attn_output = attn_output.transpose(1, 2)
1042
+
1043
+ attn_output = attn_output.reshape(bsz, q_len, self.q_num_heads * self.v_dim)
1044
+ attn_output = self.o_proj(attn_output)
1045
+
1046
+ # F.scaled_dot_product_attention does not expose attention weights, so there is nothing to return here.
1047
+ attn_weights = None
1048
+
1049
+ return attn_output, attn_weights, past_states
1050
+
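The `_expand_kv` helper above implements grouped-query attention: each of the `k_num_heads` key/value heads is repeated `n_group` times so that several query heads share it. A standalone shape check with made-up head counts:

```python
import torch

bsz, q_len, head_dim = 2, 7, 16
q_num_heads, kv_num_heads = 8, 2
n_group = q_num_heads // kv_num_heads   # 4 query heads share each kv head

key = torch.randn(bsz, kv_num_heads, q_len, head_dim)
key = torch.repeat_interleave(key, n_group, dim=1)[:, :q_num_heads]
print(key.shape)  # torch.Size([2, 8, 7, 16])
```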
1051
+
1052
+ class MLP(nn.Module):
1053
+ def __init__(self, config: PlamoConfig) -> None:
1054
+ super().__init__()
1055
+ self.config = config
1056
+ self.hidden_size = config.hidden_size
1057
+ self.intermediate_size = config.intermediate_size
1058
+ self.gate_up_proj = torch.nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
1059
+ self.down_proj = torch.nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
1060
+
1061
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1062
+ h = self.gate_up_proj(x)
1063
+ h = _swiglu(h)
1064
+ return self.down_proj(h) # type: ignore
1065
+
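`_swiglu` is defined earlier in this file; assuming it follows the usual SwiGLU formulation (split the doubled projection into a gate and a value, apply SiLU to the gate, multiply), the MLP corresponds roughly to this sketch:

```python
import torch
import torch.nn.functional as F

def swiglu_sketch(h: torch.Tensor) -> torch.Tensor:
    # Assumed formulation for illustration; the real `_swiglu` lives earlier in modeling_plamo.py.
    gate, up = h.chunk(2, dim=-1)
    return F.silu(gate) * up

x = torch.randn(2, 4, 32)      # (bsize, length, 2 * intermediate_size)
print(swiglu_sketch(x).shape)  # torch.Size([2, 4, 16])
```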
1066
+
1067
+ class PlamoDecoderLayer(torch.nn.Module):
1068
+ def __init__(self, config: PlamoConfig, is_mamba: bool, layer_idx: int) -> None:
1069
+ super().__init__()
1070
+ self.config = config
1071
+ self.hidden_size = config.hidden_size
1072
+ self.is_mamba = is_mamba
1073
+ self.mixer: torch.nn.Module
1074
+ if is_mamba:
1075
+ self.mixer = Mamba(config, layer_idx)
1076
+ else:
1077
+ self.mixer = Attention(config, layer_idx)
1078
+ self.mlp = MLP(config)
1079
+ """
1080
+ Note: model performance degraded when all offsets were set to 1, hence the smaller offsets used for the post-norms below.
1081
+ """
1082
+ self.pre_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0)
1083
+ self.post_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / 5)
1084
+ self.pre_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0)
1085
+ self.post_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / (5**1.5))
1086
+
1087
+ def forward(
1088
+ self,
1089
+ hidden_states: torch.Tensor,
1090
+ attention_mask: Optional[torch.Tensor] = None,
1091
+ past_state: Optional[PlamoCache] = None,
1092
+ output_attentions: Optional[bool] = False,
1093
+ ) -> Tuple[Any, ...]:
1094
+ # from LlamaDecoder
1095
+ residual = hidden_states
1096
+ hidden_states = self.pre_mixer_norm(hidden_states)
1097
+
1098
+ # Self Attention
1099
+ if self.is_mamba:
1100
+ hidden_states_sa, present_key_value = self.mixer(
1101
+ hidden_states=hidden_states,
1102
+ attention_mask=attention_mask,
1103
+ past_states=past_state,
1104
+ )
1105
+ self_attn_weights = None
1106
+ else:
1107
+ hidden_states_sa, self_attn_weights, present_key_value = self.mixer(
1108
+ hidden_states=hidden_states,
1109
+ attention_mask=attention_mask,
1110
+ past_states=past_state,
1111
+ output_attentions=output_attentions,
1112
+ )
1113
+
1114
+ hidden_states_sa = self.post_mixer_norm(hidden_states_sa)
1115
+ hidden_states = residual + hidden_states_sa
1116
+
1117
+ residual = hidden_states
1118
+ hidden_states = self.pre_mlp_norm(hidden_states)
1119
+
1120
+ # Fully Connected
1121
+ hidden_states_mlp = self.mlp(hidden_states)
1122
+
1123
+ # Residual
1124
+ hidden_states_mlp = self.post_mlp_norm(hidden_states_mlp)
1125
+ hidden_states = residual + hidden_states_mlp
1126
+
1127
+ outputs: Any = (hidden_states,)
1128
+
1129
+ if output_attentions:
1130
+ outputs += (self_attn_weights,)
1131
+
1132
+ return outputs # type: ignore
1133
+
1134
+
1135
+ def is_mamba(config: PlamoConfig, i: int) -> bool:
1136
+ if not config.mamba_enabled:
1137
+ return False
1138
+ assert config.mamba_step > 1
1139
+ assert i < config.num_hidden_layers
1140
+
1141
+ if config.num_hidden_layers <= (config.mamba_step // 2):
1142
+ # use attention in last layer
1143
+ return i != config.num_hidden_layers - 1
1144
+ return (i % config.mamba_step) != (config.mamba_step // 2)
1145
+
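When Mamba is enabled, the rule above places an attention layer at every `mamba_step`-th position (at offset `mamba_step // 2`) and uses Mamba everywhere else. A standalone sketch of the resulting pattern for a hypothetical configuration:

```python
# Hypothetical values for illustration; the real ones come from PlamoConfig.
num_hidden_layers, mamba_step = 8, 4

def layer_is_mamba(i: int) -> bool:
    if num_hidden_layers <= mamba_step // 2:
        return i != num_hidden_layers - 1
    return (i % mamba_step) != (mamba_step // 2)

print(["Mamba" if layer_is_mamba(i) else "Attention" for i in range(num_hidden_layers)])
# ['Mamba', 'Mamba', 'Attention', 'Mamba', 'Mamba', 'Mamba', 'Attention', 'Mamba']
```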
1146
+
1147
+ class PlamoDecoder(torch.nn.Module):
1148
+ def __init__(self, config: PlamoConfig) -> None:
1149
+ super().__init__()
1150
+
1151
+ self.layers = torch.nn.ModuleList(
1152
+ [
1153
+ PlamoDecoderLayer(config, is_mamba=is_mamba(config, i), layer_idx=i)
1154
+ for i in range(config.num_hidden_layers)
1155
+ ]
1156
+ )
1157
+
1158
+ def forward(self, x: DecoderInput) -> DecoderOutput:
1159
+ all_hidden_states: Optional[Tuple[torch.Tensor, ...]] = () if x.output_hidden_states else None
1160
+ all_self_attns: Optional[Tuple[torch.Tensor, ...]] = () if x.output_attentions else None
1161
+ hidden_states = x.hidden_states
1162
+
1163
+ for decoder_layer in self.layers:
1164
+ if x.output_hidden_states:
1165
+ assert all_hidden_states is not None
1166
+ all_hidden_states += (hidden_states,)
1167
+
1168
+ if self.training and x.gradient_checkpointing:
1169
+
1170
+ def create_custom_forward(module): # type: ignore
1171
+ def custom_forward(*inputs): # type: ignore
1172
+ # None for past_key_value
1173
+ return module(*inputs, x.output_attentions, None)
1174
+
1175
+ return custom_forward
1176
+
1177
+ layer_outputs = torch.utils.checkpoint.checkpoint(
1178
+ create_custom_forward(decoder_layer), # type: ignore
1179
+ hidden_states,
1180
+ x.attention_mask,
1181
+ None,
1182
+ )
1183
+ else:
1184
+ layer_outputs = decoder_layer(
1185
+ hidden_states,
1186
+ attention_mask=x.attention_mask,
1187
+ past_state=x.past_states,
1188
+ output_attentions=x.output_attentions,
1189
+ )
1190
+
1191
+ hidden_states = layer_outputs[0]
1192
+
1193
+ if x.output_attentions:
1194
+ assert layer_outputs[1] is not None
1195
+ assert all_self_attns is not None
1196
+ all_self_attns += (layer_outputs[1],)
1197
+ return DecoderOutput(hidden_states, all_hidden_states, all_self_attns)
1198
+
1199
+
1200
+ class PlamoPreTrainedModel(PreTrainedModel): # type: ignore
1201
+ config_class = PlamoConfig
1202
+ _no_split_modules: List[str]
1203
+ base_model_prefix = "model"
1204
+ supports_gradient_checkpointing = True
1205
+ _no_split_modules = ["PlamoDecoderLayer"]
1206
+ _skip_keys_device_placement = "past_key_values"
1207
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
1208
+
1209
+ def _init_weights(self, module: torch.nn.Module) -> None:
1210
+ std = 0.02
1211
+ if isinstance(module, nn.Linear):
1212
+ module.weight.data.normal_(mean=0.0, std=std)
1213
+ if module.bias is not None:
1214
+ module.bias.data.zero_()
1215
+ elif isinstance(module, nn.Embedding):
1216
+ module.weight.data.normal_(mean=0.0, std=std)
1217
+ if module.padding_idx is not None:
1218
+ module.weight.data[module.padding_idx].zero_()
1219
+
1220
+ def _set_gradient_checkpointing(self, module: torch.nn.Module, value: bool = False) -> None:
1221
+ module.gradient_checkpointing = value # type: ignore
1222
+
1223
+
1224
+ class PlamoModel(PlamoPreTrainedModel):
1225
+ def __init__(self, config: PlamoConfig):
1226
+ super().__init__(config)
1227
+ assert config.eval_attention_n_bit is None
1228
+ assert config.eval_mlp_n_bit is None
1229
+
1230
+ self.padding_idx = config.pad_token_id
1231
+ self.vocab_size = config.vocab_size
1232
+
1233
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1234
+ if config.image_feature_size is not None:
1235
+ if config.image_proj_type == "mlp":
1236
+ self.image_proj = MLPImageProjector(config) # type: ignore
1237
+ elif config.image_proj_type == "linear":
1238
+ self.image_proj = nn.Linear(config.image_feature_size, config.hidden_size, bias=False) # type: ignore
1239
+ else:
1240
+ raise ValueError(f"Unknown image_proj_type: {config.image_proj_type}")
1241
+ self.layers = PlamoDecoder(config) # type: ignore
1242
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1243
+
1244
+ self.gradient_checkpointing = False
1245
+ # Initialize weights and apply final processing
1246
+ self.post_init()
1247
+
1248
+ def get_input_embeddings(self) -> torch.nn.Embedding:
1249
+ return self.embed_tokens
1250
+
1251
+ def set_input_embeddings(self, value: torch.nn.Embedding) -> None:
1252
+ self.embed_tokens = value
1253
+
1254
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
1255
+ def _prepare_decoder_attention_mask(
1256
+ self,
1257
+ attention_mask: torch.Tensor,
1258
+ input_shape: Tuple[int, int],
1259
+ inputs_embeds: Optional[torch.Tensor],
1260
+ past_key_values_length: int,
1261
+ ) -> Optional[torch.Tensor]:
1262
+ # create causal mask
1263
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1264
+ combined_attention_mask: Optional[torch.Tensor] = None
1265
+ if input_shape[-1] > 1:
1266
+ assert inputs_embeds is not None
1267
+ combined_attention_mask = _make_causal_mask(
1268
+ input_shape,
1269
+ inputs_embeds.dtype,
1270
+ device=inputs_embeds.device,
1271
+ past_key_values_length=past_key_values_length,
1272
+ )
1273
+ input_shape = (input_shape[0], combined_attention_mask.shape[2])
1274
+
1275
+ if attention_mask is not None:
1276
+ if attention_mask.dim() == 4:
1277
+ # Custom 4D attention mask
1278
+ expanded_attn_mask = attention_mask
1279
+ else:
1280
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1281
+ assert inputs_embeds is not None
1282
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
1283
+ inputs_embeds.device
1284
+ )
1285
+ combined_attention_mask = (
1286
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
1287
+ )
1288
+
1289
+ return combined_attention_mask
1290
+
1291
+ def forward(
1292
+ self,
1293
+ input_ids: Optional[torch.LongTensor] = None,
1294
+ attention_mask: Optional[torch.Tensor] = None,
1295
+ position_ids: Optional[torch.Tensor] = None,
1296
+ past_key_values: Optional[PlamoCache] = None,
1297
+ inputs_embeds: Optional[torch.Tensor] = None,
1298
+ image_features: Optional[torch.Tensor] = None,
1299
+ use_cache: Optional[bool] = None,
1300
+ output_attentions: Optional[bool] = None,
1301
+ output_hidden_states: Optional[bool] = None,
1302
+ return_dict: Optional[bool] = None,
1303
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1304
+ assert input_ids is not None
1305
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1306
+ output_hidden_states = (
1307
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1308
+ )
1309
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1310
+
1311
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1312
+
1313
+ # retrieve input_ids and inputs_embeds
1314
+ if input_ids is not None and inputs_embeds is not None:
1315
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1316
+ elif input_ids is not None:
1317
+ batch_size, seq_length = input_ids.shape
1318
+ else:
1319
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1320
+
1321
+ seq_length_with_past = seq_length
1322
+ past_key_values_length = 0
1323
+
1324
+ if past_key_values is not None:
1325
+ past_key_values_length = past_key_values.get_seq_length()
1326
+ seq_length_with_past = seq_length_with_past + past_key_values_length
1327
+
1328
+ if inputs_embeds is None:
1329
+ inputs_embeds = self.embed_tokens(input_ids)
1330
+
1331
+ if image_features is not None:
1332
+ assert self.config.image_token_id is not None
1333
+ image_embeds = self.image_proj(image_features)
1334
+ assert image_embeds.shape == inputs_embeds.shape, (image_embeds.shape, inputs_embeds.shape)
1335
+ mask = input_ids == self.config.image_token_id
1336
+ inputs_embeds[mask] = image_embeds[mask]
1337
+
1338
+ # embed positions
1339
+ require_attn_mask = False
1340
+ if not self.training or past_key_values is not None:
1341
+ require_attn_mask = True
1342
+ if seq_length_with_past >= self.config.attention_window_size:
1343
+ require_attn_mask = True
1344
+ if require_attn_mask and attention_mask is None:
1345
+ attention_mask = torch.ones(
1346
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
1347
+ )
1348
+ if attention_mask is not None:
1349
+ attention_mask = self._prepare_decoder_attention_mask(
1350
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
1351
+ )
1352
+
1353
+ hidden_states = inputs_embeds
1354
+
1355
+ if self.gradient_checkpointing and self.training:
1356
+ if use_cache:
1357
+ use_cache = False
1358
+
1359
+ if use_cache and past_key_values is None:
1360
+ past_key_values = PlamoCache(self.config)
1361
+
1362
+ # decoder layers
1363
+ out = self.layers(
1364
+ DecoderInput(
1365
+ hidden_states,
1366
+ attention_mask,
1367
+ past_key_values,
1368
+ output_hidden_states,
1369
+ output_attentions,
1370
+ self.gradient_checkpointing,
1371
+ )
1372
+ )
1373
+ assert isinstance(out, DecoderOutput)
1374
+ hidden_states = out.hidden_states
1375
+ all_hidden_states = out.all_hidden_states
1376
+ all_self_attns = out.all_self_attns
1377
+
1378
+ hidden_states = self.norm(hidden_states)
1379
+
1380
+ # add hidden states from the last decoder layer
1381
+ if output_hidden_states:
1382
+ assert all_hidden_states is not None
1383
+ all_hidden_states += (hidden_states,)
1384
+
1385
+ if not return_dict:
1386
+ return tuple(
1387
+ v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None
1388
+ )
1389
+ return BaseModelOutputWithPast(
1390
+ last_hidden_state=hidden_states,
1391
+ past_key_values=past_key_values,
1392
+ hidden_states=all_hidden_states,
1393
+ attentions=all_self_attns,
1394
+ )
1395
+
1396
+
1397
+ class PlamoForCausalLM(PlamoPreTrainedModel):
1398
+ _tied_weights_keys = ["lm_head.weight"]
1399
+
1400
+ # Without this, the model cannot be loaded onto a meta device.
1401
+ # Relevant code:
1402
+ # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/modeling_utils.py#L4376-L4381
1403
+ # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/modeling_utils.py#L356
1404
+ # https://github.com/pytorch/pytorch/blob/v2.4.1/torch/nn/modules/module.py#L2068
1405
+ _supports_param_buffer_assignment = False
1406
+
1407
+ def __init__(self, config: PlamoConfig) -> None:
1408
+ super().__init__(config)
1409
+ self.model = PlamoModel(config)
1410
+
1411
+ self.vocab_size = config.vocab_size
1412
+ vocab_size = ((self.vocab_size + 15) // 16) * 16
1413
+ self.lm_head: torch.nn.Module = nn.Linear(config.hidden_size, vocab_size, bias=False)
1414
+
1415
+ # Initialize weights and apply final processing
1416
+ self.post_init()
1417
+
1418
+ def get_input_embeddings(self) -> torch.nn.Embedding:
1419
+ return self.model.embed_tokens
1420
+
1421
+ def set_input_embeddings(self, value: torch.nn.Embedding) -> None:
1422
+ self.model.embed_tokens = value
1423
+
1424
+ def get_output_embeddings(self) -> torch.nn.Module:
1425
+ return self.lm_head
1426
+
1427
+ def set_output_embeddings(self, new_embeddings: torch.nn.Module) -> None:
1428
+ self.lm_head = new_embeddings
1429
+
1430
+ def set_decoder(self, decoder: PlamoModel) -> None:
1431
+ self.model = decoder
1432
+
1433
+ def get_decoder(self) -> PlamoModel:
1434
+ return self.model
1435
+
1436
+ def forward( # type: ignore
1437
+ self,
1438
+ input_ids: Optional[torch.LongTensor] = None,
1439
+ attention_mask: Optional[torch.Tensor] = None,
1440
+ position_ids: Optional[torch.Tensor] = None,
1441
+ past_key_values: Optional[PlamoCache] = None,
1442
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1443
+ image_features: Optional[torch.Tensor] = None,
1444
+ labels: Optional[torch.LongTensor] = None,
1445
+ use_cache: Optional[bool] = None,
1446
+ output_attentions: Optional[bool] = None,
1447
+ output_hidden_states: Optional[bool] = None,
1448
+ return_dict: Optional[bool] = None,
1449
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1450
+ r"""
1451
+ Args:
1452
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1453
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1454
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1455
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1456
+
1457
+ Returns:
1458
+
1459
+ Example:
1460
+
1461
+ ```python
1462
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
1463
+
1464
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1465
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1466
+
1467
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1468
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1469
+
1470
+ >>> # Generate
1471
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1472
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1473
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1474
+ ```"""
1475
+ assert input_ids is not None
1476
+
1477
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1478
+ output_hidden_states = (
1479
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1480
+ )
1481
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1482
+
1483
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1484
+ outputs = self.model(
1485
+ input_ids=input_ids,
1486
+ attention_mask=attention_mask,
1487
+ position_ids=position_ids,
1488
+ past_key_values=past_key_values,
1489
+ inputs_embeds=inputs_embeds,
1490
+ image_features=image_features,
1491
+ use_cache=use_cache,
1492
+ output_attentions=output_attentions,
1493
+ output_hidden_states=output_hidden_states,
1494
+ return_dict=return_dict,
1495
+ )
1496
+
1497
+ hidden_states = outputs[0]
1498
+ logits = self.lm_head(hidden_states)
1499
+ logits = logits[..., : self.vocab_size]
1500
+
1501
+ loss = None
1502
+ if labels is not None:
1503
+ # Shift so that tokens < n predict n
1504
+ shift_logits = logits[..., :-1, :].contiguous()
1505
+ shift_labels = labels[..., 1:].contiguous()
1506
+ # Flatten the tokens
1507
+ loss_fct = nn.CrossEntropyLoss()
1508
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1509
+ shift_labels = shift_labels.view(-1)
1510
+ # Enable model parallelism
1511
+ shift_labels = shift_labels.to(shift_logits.device)
1512
+ loss = loss_fct(shift_logits, shift_labels)
1513
+
1514
+ if not return_dict:
1515
+ output = (logits,) + outputs[1:]
1516
+ return (loss,) + output if loss is not None else output
1517
+
1518
+ return CausalLMOutputWithPast(
1519
+ loss=loss,
1520
+ logits=logits,
1521
+ past_key_values=outputs.past_key_values,
1522
+ hidden_states=outputs.hidden_states,
1523
+ attentions=outputs.attentions,
1524
+ )
1525
+
1526
+ def prepare_inputs_for_generation(
1527
+ self,
1528
+ input_ids: torch.Tensor,
1529
+ past_key_values: Optional[PlamoCache] = None,
1530
+ attention_mask: Optional[torch.Tensor] = None,
1531
+ inputs_embeds: Optional[torch.Tensor] = None,
1532
+ image_features: Optional[torch.Tensor] = None,
1533
+ **kwargs: Any,
1534
+ ) -> Dict[str, Any]:
1535
+ if past_key_values:
1536
+ input_ids = input_ids[:, -1:]
1537
+ if image_features is not None:
1538
+ image_features = image_features[:, -1:, :]
1539
+
1540
+ position_ids = kwargs.get("position_ids", None)
1541
+ if attention_mask is not None and position_ids is None:
1542
+ # create position_ids on the fly for batch generation
1543
+ position_ids = attention_mask.long().cumsum(-1) - 1
1544
+ position_ids.masked_fill_(attention_mask == 0, 1)
1545
+ if past_key_values:
1546
+ position_ids = position_ids[:, -1].unsqueeze(-1)
1547
+
1548
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1549
+ if inputs_embeds is not None and past_key_values is None:
1550
+ model_inputs: Dict[str, Any] = {"inputs_embeds": inputs_embeds}
1551
+ else:
1552
+ model_inputs = {"input_ids": input_ids}
1553
+
1554
+ model_inputs.update(
1555
+ {
1556
+ "position_ids": position_ids,
1557
+ "past_key_values": past_key_values,
1558
+ "use_cache": kwargs.get("use_cache"),
1559
+ "attention_mask": attention_mask,
1560
+ "image_features": image_features,
1561
+ }
1562
+ )
1563
+ return model_inputs
1564
+
1565
+ @staticmethod
1566
+ def _reorder_cache(past_key_values: PlamoCache, beam_idx: torch.Tensor) -> PlamoCache:
1567
+ past_key_values.reorder_cache(beam_idx)
1568
+ return past_key_values
1569
+
1570
+
1571
+ class MLPImageProjector(nn.Module):
1572
+ def __init__(self, config: PlamoConfig) -> None:
1573
+ super().__init__()
1574
+ self.config = config
1575
+
1576
+ assert config.image_feature_size is not None # for typing
1577
+
1578
+ # nn.LayerNorm is not supported by PFVM, so use RMSNorm + Bias instead to approximate this.
1579
+ self.norm0 = RMSNorm(config.image_feature_size, eps=config.rms_norm_eps)
1580
+ self.bias0 = Bias(config.image_feature_size)
1581
+
1582
+ # PFVM doesn't support Linear with bias, so add bias manually afterwards.
1583
+ self.linear1 = nn.Linear(config.image_feature_size, config.hidden_size, bias=False)
1584
+ self.bias1 = Bias(config.hidden_size)
1585
+ self.act1 = nn.GELU()
1586
+
1587
+ self.linear2 = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
1588
+ self.bias2 = Bias(config.hidden_size)
1589
+
1590
+ def forward(
1591
+ self,
1592
+ hidden_states: torch.Tensor,
1593
+ ) -> torch.Tensor:
1594
+ hidden_states = self.norm0(hidden_states)
1595
+ hidden_states = self.bias0(hidden_states)
1596
+
1597
+ hidden_states = self.linear1(hidden_states)
1598
+ hidden_states = self.bias1(hidden_states)
1599
+ hidden_states = self.act1(hidden_states)
1600
+
1601
+ hidden_states = self.linear2(hidden_states)
1602
+ hidden_states = self.bias2(hidden_states)
1603
+
1604
+ return hidden_states
1605
+
1606
+
1607
+ class Bias(nn.Module):
1608
+ def __init__(self, num_features: int) -> None:
1609
+ super().__init__()
1610
+ self._bias = nn.Parameter(torch.zeros((num_features,)))
1611
+
1612
+ def forward(
1613
+ self,
1614
+ x: torch.Tensor,
1615
+ ) -> torch.Tensor:
1616
+ return x + self._bias
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<|plamo:bos|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|plamo:eos|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<|plamo:pad|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<|plamo:unk|>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
tokenization_plamo.py ADDED
@@ -0,0 +1,392 @@
1
+ import json
2
+ import math
3
+ import os
4
+ from shutil import copyfile
5
+ from typing import Any, Optional, Tuple
6
+
7
+ import numpy as np
8
+
9
+ # NOTE: numba does not support type hints for njit: https://github.com/python/mypy/issues/16149
10
+ from numba import njit # type: ignore[attr-defined]
11
+ from numba.core import types
12
+ from numba.typed import Dict, List
13
+ from transformers.tokenization_utils import PreTrainedTokenizer
14
+ from transformers.utils import logging
15
+
16
+ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.jsonl"}
17
+ logger = logging.get_logger(__name__)
18
+
19
+ INVALID_SCORE = -20000000
20
+ UNKNOWN_SCORE = -10000000
21
+
22
+ TABLE_PIECE_LENGTH = 0
23
+ TABLE_TOKEN_ID = 1
24
+ TABLE_SCORE = 2
25
+ TABLE_PIECE_ID = 3
26
+
27
+ PATH_TOKEN_LENGTH = 0
28
+ PATH_TOKEN_ID = 1
29
+ PATH_NUM_TOKENS = 2
30
+
31
+
32
+ class AhoCorasick:
33
+ def __init__(self) -> None:
34
+ # List of tokens in the vocabulary.
35
+ self._tokens: list[str]
36
+
37
+ # A mapping from a byte code point to a token ID, used for byte fallback.
38
+ self._bytes: np.ndarray
39
+
40
+ # A mapping from a suffix's piece code to a suffix ID.
41
+ #
42
+ # Typically, the Aho-Corasick algorithm builds a Trie and adds suffix links between nodes
43
+ # of the Trie. In this implementation, a suffix ID corresponds to a node in the trie, and
44
+ # a piece code to an edge (in other words, a pair of a node and the next character).
45
+ #
46
+ # A piece code is a 64-bit integer:
47
+ # - The upper 32 bits store the Unicode code point of the first character.
48
+ # - The lower 32 bits store the suffix ID of the remaining suffix.
49
+ #
50
+ # A suffix ID is an integer indicating the starting position in the _table.
51
+ self._to_suffix_id: Dict[types.int64, types.int32]
52
+
53
+ # Flattened table representing the Trie structure for the Aho-Corasick algorithm.
54
+ # It stores information including scores for each piece (prefix) within each suffix.
55
+ # It is flattened for memory efficiency and performance. Suffixes are stored in
56
+ # lexicographical order of their reversed strings, which improves memory access locality
57
+ # when exploring new characters starting from the string's end. Pieces within a suffix are
58
+ # stored in the decreasing order of their lengths.
59
+ #
60
+ # Each piece (a prefix of the suffix) contains four pieces of information:
61
+ # - TABLE_PIECE_LENGTH: Length of the piece.
62
+ # - TABLE_TOKEN_ID: Token ID (or -1 if the piece is not a valid token).
63
+ # - TABLE_SCORE: Score (or INVALID_SCORE if the piece is not a valid token).
64
+ # - TABLE_PIECE_ID: Piece ID of the suffix.
65
+ #
66
+ # Each suffix also includes a sentinel row with a length of 1, a score of UNKNOWN_SCORE,
67
+ # and a token ID of -1. Sentinel rows are identified by the score being UNKNOWN_SCORE.
68
+ self._table: np.ndarray
69
+
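A piece code packs the first character's code point and the suffix ID of the remainder into a single 64-bit key, which is what `_to_suffix_id` is indexed by. A standalone example with a made-up suffix ID:

```python
suffix_id = 42                        # hypothetical ID of the suffix "bc"
code = ord("a") << 32 | suffix_id     # piece code for the pair ('a', "bc"), i.e. the suffix "abc"
print(hex(code))                      # 0x610000002a
print(code >> 32, code & 0xFFFFFFFF)  # 97 42
```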
70
+ def build(self, vocab: list[Any]) -> None:
71
+ self._bytes = np.zeros(256, dtype=np.int32)
72
+ self._to_suffix_id = Dict.empty(key_type=types.int64, value_type=types.int32)
73
+
74
+ # Build suffix_to_score and token_to_token_id.
75
+ # The suffix_to_score dictionary maps a suffix to its score. It also includes all suffixes
76
+ # of the token for the Trie structure for the Aho-Corasick algorithm. If a suffix is not a
77
+ # valid token, its score is set to math.nan.
78
+ # The token_to_token_id dictionary maps a token to its token ID.
79
+ suffix_to_score: dict[str, float] = {}
80
+ token_to_token_id: dict[str, int] = {}
81
+ self._tokens = []
82
+ for token_id, row in enumerate(vocab):
83
+ assert isinstance(row[0], str), row
84
+ assert isinstance(row[1], (int, float)), row
85
+
86
+ token = str(row[0])
87
+ self._tokens.append(token)
88
+ token_to_token_id[token] = token_id
89
+
90
+ # Special handling for byte tokens.
91
+ if len(row) > 2 and row[2] == "BYTE":
92
+ assert len(token) == 6 and token.startswith("<0x") and token.endswith(">"), row[0]
93
+ self._bytes[int(row[0][3:5], 16)] = token_id
94
+ continue
95
+
96
+ suffix_to_score[token] = float(row[1])
97
+ # Ensure that all suffixes are included in suffix_to_score.
98
+ for i in range(1, len(token)):
99
+ suffix_to_score[token[i:]] = suffix_to_score.get(token[i:], math.nan)
100
+
101
+ # Ensure all byte tokens are set.
102
+ for i in range(256):
103
+ assert self._bytes[i] != 0, f"Byte token for <0x{i:02X}> is not set."
104
+
105
+ # List suffixes in lexicographical order of their reversed strings.
106
+ suffixes = list(suffix_to_score.keys())
107
+ suffixes.append("")
108
+ suffixes.sort(key=lambda x: x[::-1])
109
+
110
+ # Build suffix_to_id, which is a mapping from a suffix to a suffix ID, and _to_suffix_id,
111
+ # which is a mapping from a piece code to a suffix ID.
112
+ suffix_to_id: dict[str, int] = {}
113
+ num_pieces = 0
114
+ for s in suffixes:
115
+ suffix_to_id[s] = num_pieces
116
+ if s != "":
117
+ self._to_suffix_id[ord(s[0]) << 32 | suffix_to_id[s[1:]]] = np.int32(num_pieces)
118
+ num_pieces += 1 + sum(s[:i] in suffix_to_score for i in range(1, len(s) + 1))
119
+ assert suffix_to_id[""] == 0, suffix_to_id[""]
120
+
121
+ # Build _table, which is a flattened table representing the Trie structure for the Aho-Corasick.
122
+ self._table = np.zeros((num_pieces, 4), dtype=np.int32)
123
+ i = 0
124
+ for suffix in suffixes:
125
+ # Add all prefixes of the suffix to the table.
126
+ for piece_length in range(len(suffix), 0, -1):
127
+ piece = suffix[:piece_length]
128
+ score = suffix_to_score.get(piece, None)
129
+ if score is None:
130
+ continue
131
+ self._table[i, TABLE_PIECE_LENGTH] = piece_length
132
+ self._table[i, TABLE_TOKEN_ID] = token_to_token_id.get(piece, -1)
133
+ self._table[i, TABLE_SCORE] = round(score * 1e4) if math.isfinite(score) else INVALID_SCORE
134
+ self._table[i, TABLE_PIECE_ID] = suffix_to_id[piece]
135
+ i += 1
136
+
137
+ # Add a sentinel row.
138
+ self._table[i, TABLE_PIECE_LENGTH] = 1
139
+ self._table[i, TABLE_TOKEN_ID] = -1
140
+ self._table[i, TABLE_SCORE] = UNKNOWN_SCORE
141
+ i += 1
142
+ assert i == num_pieces, (i, num_pieces)
143
+
144
+ @staticmethod
145
+ @njit
146
+ def _encode(
147
+ to_suffix_id: Dict[types.int64, types.int32],
148
+ table: np.ndarray,
149
+ bytes: np.ndarray,
150
+ data: np.ndarray,
151
+ ) -> np.ndarray:
152
+ # Initialize scores array with a high value and set the score at the end to 0.
153
+ # This array keeps track of the minimum cost (best score) to encode from each position to the end.
154
+ scores = np.full((len(data) + 1,), 2**60, dtype=np.int64)
155
+ scores[-1] = 0
156
+
157
+ # Path array to store the best path information.
158
+ # The path array keeps track of token length, token ID, and number of tokens needed to encode.
159
+ path = np.zeros((len(data) + 1, 3), dtype=np.int32)
160
+
161
+ # Initialize suffix_id to 0, which represents the root of the Trie.
162
+ suffix_id = 0
163
+
164
+ # Process the input data from the end to the beginning.
165
+ for i in range(len(data) - 1, -1, -1):
166
+ c = data[i]
167
+
168
+ # Find the next suffix ID by iterating the suffix IDs of prefixes of the current suffix.
169
+ # NOTE: If no suffix ID is found, suffix_id will be set to 0.
170
+ for p in range(suffix_id, len(table)):
171
+ suffix_id = to_suffix_id.get(c << 32 | table[p, TABLE_PIECE_ID], np.int32(0))
172
+ # If a next suffix ID is found or a sentinel row is reached, break the loop.
173
+ if suffix_id > 0 or table[p, TABLE_SCORE] == UNKNOWN_SCORE:
174
+ break
175
+
176
+ # Update the best path to the current position. If multiple paths have the same score,
177
+ # this chooses the longest prefix as the best path (table is sorted in the decreasing
178
+ # order of piece length).
179
+ for p in range(suffix_id, len(table)):
180
+ score = table[p, TABLE_SCORE]
181
+ if score > INVALID_SCORE:
182
+ piece_length = table[p, TABLE_PIECE_LENGTH]
183
+ s = scores[i + piece_length] - score
184
+ if s < scores[i]:
185
+ scores[i] = s
186
+ path[i, PATH_TOKEN_LENGTH] = piece_length
187
+ path[i, PATH_TOKEN_ID] = table[p, TABLE_TOKEN_ID]
188
+ path[i, PATH_NUM_TOKENS] = path[i + piece_length, PATH_NUM_TOKENS] + 1
189
+ if score == UNKNOWN_SCORE:
190
+ # Add number of bytes to represent `c` in UTF-8 (minus 1; 1 is already
191
+ # added above).
192
+ path[i, PATH_NUM_TOKENS] += (c >= 0x80) + (c >= 0x800) + (c >= 0x10000)
193
+
194
+ # If it reaches a sentinel row, break the loop.
195
+ if score == UNKNOWN_SCORE:
196
+ break
197
+
198
+ # Decode the best path from the beginning to get the token IDs.
199
+ pos = 0
200
+ token_ids = np.zeros(path[0, PATH_NUM_TOKENS], dtype=np.int32)
201
+ token_pos = 0
202
+ while pos < len(data):
203
+ if path[pos, PATH_TOKEN_ID] >= 0:
204
+ token_ids[token_pos] = path[pos, PATH_TOKEN_ID]
205
+ token_pos += 1
206
+ else:
207
+ # Fall back to byte tokens.
208
+ c = data[pos]
209
+ s = 1 + (c >= 0x80) + (c >= 0x800) + (c >= 0x10000)
210
+ # Add byte tokens representing UTF-8 bytes.
211
+ for i in range(s):
212
+ b = c if s == 1 else (0xF00 >> s) & 0xFF if i == 0 else 0x80
213
+ token_ids[token_pos] = bytes[b | ((c >> (s - i - 1) * 6) & 0x3F)]
214
+ token_pos += 1
215
+
216
+ # Ensure that pos increases by at least 1.
217
+ assert path[pos, PATH_TOKEN_LENGTH] > 0, (pos, path[pos])
218
+ pos += path[pos, PATH_TOKEN_LENGTH]
219
+
220
+ return token_ids
221
+
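As a sanity check on the byte-fallback arithmetic above (the `(0xF00 >> s) & 0xFF` prefix trick), here is a small standalone sketch, not part of the committed file, that re-derives the UTF-8 bytes of a code point and compares them against Python's own encoder.

def utf8_bytes(c: int) -> bytes:
    # Number of UTF-8 bytes needed for code point c (1..4), same formula as in _encode.
    s = 1 + (c >= 0x80) + (c >= 0x800) + (c >= 0x10000)
    out = []
    for i in range(s):
        # The first byte carries the length prefix (0xC0 / 0xE0 / 0xF0);
        # continuation bytes start with 0x80; payload bits are packed 6 at a time.
        b = c if s == 1 else (0xF00 >> s) & 0xFF if i == 0 else 0x80
        out.append(b | ((c >> (s - i - 1) * 6) & 0x3F))
    return bytes(out)

assert utf8_bytes(ord("A")) == b"A"
assert utf8_bytes(ord("あ")) == "あ".encode("utf-8")          # e3 81 82
assert utf8_bytes(0x10348) == "\U00010348".encode("utf-8")    # 4-byte case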
222
+ def encode(self, data: str) -> np.ndarray:
223
+ """Encodes a string into a sequence of token IDs."""
224
+ return np.asarray(
225
+ self._encode(
226
+ self._to_suffix_id,
227
+ self._table,
228
+ self._bytes,
229
+ # Convert a string into a numpy array of Unicode code points.
230
+ # NOTE: This skips UTF-32 BOM.
231
+ np.frombuffer(data.encode("utf-32"), dtype=np.int32)[1:],
232
+ )
233
+ )
234
+
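A quick illustration, assuming only numpy and the standard library, of why encode() slices off the first element: Python's "utf-32" codec prepends a BOM (U+FEFF), and the remaining 32-bit words are exactly the Unicode code points handed to _encode.

import numpy as np

cps = np.frombuffer("aあ".encode("utf-32"), dtype=np.int32)
assert cps[0] == 0xFEFF                        # BOM added by the "utf-32" codec
assert list(cps[1:]) == [ord("a"), ord("あ")]  # code points 0x61, 0x3042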
235
+ def encode_as_tokens(self, data: str) -> list[str]:
236
+ """Encodes a string into a sequence of tokens."""
237
+ return [self._tokens[token_id] for token_id in self.encode(data)]
238
+
239
+
240
+ class PlamoTokenizer(PreTrainedTokenizer): # type: ignore
241
+ vocab_files_names = VOCAB_FILES_NAMES
242
+ model_input_names = ["input_ids", "attention_mask"]
243
+
244
+ _save_files = [
245
+ "special_tokens_map.json",
246
+ "tokenization_plamo.py",
247
+ "tokenizer.jsonl",
248
+ "tokenizer_config.json",
249
+ ]
250
+
251
+ def __init__(
252
+ self,
253
+ vocab_file: str,
254
+ unk_token: str = "<|plamo:unk|>",
255
+ bos_token: str = "<|plamo:bos|>",
256
+ eos_token: str = "<|plamo:eos|>",
257
+ pad_token: str = "<|plamo:pad|>",
258
+ cls_token: Optional[str] = None,
259
+ sep_token: Optional[str] = None,
260
+ mask_token: Optional[str] = None,
261
+ clean_up_tokenization_spaces: bool = False,
262
+ **kwargs: Any,
263
+ ) -> None:
264
+ """Tokenizer for PLaMo.
265
+
266
+ Args:
267
+ vocab_file (str): Vocabulary file path.
268
+ unk_token (str): Unknown token.
269
+ bos_token (str): Beginning of sentence token.
270
+ eos_token (str): End of sentence token.
271
+ pad_token (str): Padding token.
272
+ cls_token (str):
273
+ Classification token, to extract a summary of an input sequence leveraging self-attention along the
274
+ full depth of the model.
275
+ sep_token (str): Separation token, to separate context and query in an input sequence.
276
+ mask_token (str): Mask token, to use when training a model with masked-language modeling.
277
+ clean_up_tokenization_spaces (bool): Whether or not to clean up the tokenization spaces.
278
+ num_threads (int):
279
+ Number of threads. This value will be ignored if one of `PLAMO_TOKENIZER_NUM_THREADS` or
280
+ `RAYON_NUM_THREADS` is set as an environment variable.
281
+ """
282
+ if "add_bos_token" not in kwargs:
283
+ kwargs["add_bos_token"] = False
284
+ if "add_eos_token" not in kwargs:
285
+ kwargs["add_eos_token"] = False
286
+ self.data: list[Any] = [json.loads(line) for line in open(vocab_file, "r", encoding="utf-8")]
287
+ self.vocab: dict[str, int] = {v[0]: i for i, v in enumerate(self.data)}
288
+ self.aho_corasick = AhoCorasick()
289
+ self.aho_corasick.build(self.data)
290
+ self.vocab_file = vocab_file
291
+ self.add_bos_token = kwargs["add_bos_token"]
292
+ self.add_eos_token = kwargs["add_eos_token"]
293
+
294
+ super().__init__(
295
+ vocab_file=vocab_file,
296
+ unk_token=unk_token,
297
+ bos_token=bos_token,
298
+ eos_token=eos_token,
299
+ pad_token=pad_token,
300
+ cls_token=cls_token,
301
+ sep_token=sep_token,
302
+ mask_token=mask_token,
303
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
304
+ **kwargs,
305
+ )
306
+
307
+ # the functions below are copied from hf transformers LlamaTokenizer's implementation to fix the behaviour of the tokenizer
308
+ # https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/models/llama/tokenization_llama.py
309
+
310
+ def __getstate__(self) -> dict[str, Any]:
311
+ state = self.__dict__.copy()
312
+ state["aho_corasick"] = None
313
+ return state
314
+
315
+ def __setstate__(self, d: dict[str, Any]) -> None:
316
+ self.__dict__ = d
317
+ self.aho_corasick = AhoCorasick()
318
+ self.aho_corasick.build(self.data)
319
+
320
+ @property
321
+ def vocab_size(self) -> Any:
322
+ """Returns vocab size"""
323
+ return len(self.data)
324
+
325
+ def token_to_score(self, token: str) -> Optional[float]:
326
+ """Returns score of the token"""
327
+ token_id = self.vocab.get(token, None)
328
+ return None if token_id is None else self.data[token_id][1]
329
+
330
+ def get_vocab(self) -> dict[str, int]:
331
+ """Returns vocab as a dict"""
332
+ vocab = self.vocab.copy()
333
+ vocab.update(self.added_tokens_encoder)
334
+ return vocab
335
+
336
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
337
+ """Converts a sequence of tokens (string) in a single string."""
338
+ return b"".join(
339
+ [bytes([int(t[3:5], 16)]) if t.startswith("<0x") else t.encode("utf-8") for t in tokens]
340
+ ).decode("utf-8", errors="replace")
341
+
342
+ def _tokenize(self, text: str) -> Any:
343
+ """Returns a tokenized string."""
344
+ return self.aho_corasick.encode_as_tokens(text)
345
+
346
+ def _convert_token_to_id(self, token: str) -> Any:
347
+ """Converts a token (str) in an id using the vocab."""
348
+ return self.vocab.get(token, 0)
349
+
350
+ def _convert_id_to_token(self, index: int) -> Any:
351
+ """Converts an index (integer) in a token (str) using the vocab."""
352
+ return self.data[index][0]
353
+
354
+ def build_inputs_with_special_tokens(
355
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
356
+ ) -> List[int]:
357
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
358
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
359
+
360
+ output = bos_token_id + token_ids_0 + eos_token_id
361
+
362
+ if token_ids_1 is not None:
363
+ output = output + bos_token_id + token_ids_1 + eos_token_id
364
+
365
+ return output
366
+
367
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
368
+ """
369
+ Save the vocabulary and special tokens file to a directory.
370
+
371
+ Args:
372
+ save_directory (`str`):
373
+ The directory in which to save the vocabulary.
374
+
375
+ Returns:
376
+ `Tuple(str)`: Paths to the files saved.
377
+ """
378
+ if not os.path.isdir(save_directory):
379
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
380
+ return ("",)
381
+ out_vocab_file = os.path.join(
382
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
383
+ )
384
+
385
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
386
+ copyfile(self.vocab_file, out_vocab_file)
387
+ elif not os.path.isfile(self.vocab_file):
388
+ with open(out_vocab_file, "w") as f:
389
+ for token in self.data:
390
+ print(json.dumps(token, ensure_ascii=False), file=f)
391
+
392
+ return (out_vocab_file,)
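To tie the pieces together, a hedged usage sketch that is not part of this commit. The repository id below is a placeholder, and trust_remote_code=True is required because PlamoTokenizer is wired up through the auto_map entry in tokenizer_config.json further down.

from transformers import AutoTokenizer

# "<repo-id>" is a placeholder for the actual Hugging Face repository name.
tokenizer = AutoTokenizer.from_pretrained("<repo-id>", trust_remote_code=True)

ids = tokenizer("これはテストです。")["input_ids"]
print(tokenizer.convert_ids_to_tokens(ids))
print(tokenizer.decode(ids))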
tokenizer.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,55 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<|plamo:unk|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<|plamo:bos|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "<|plamo:eos|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "3": {
30
+ "content": "<|plamo:pad|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ }
37
+ },
38
+ "auto_map": {
39
+ "AutoTokenizer": [
40
+ "tokenization_plamo.PlamoTokenizer",
41
+ null
42
+ ]
43
+ },
44
+ "bos_token": "<|plamo:bos|>",
45
+ "clean_up_tokenization_spaces": false,
46
+ "cls_token": null,
47
+ "eos_token": "<|plamo:eos|>",
48
+ "local_file_only": true,
49
+ "mask_token": null,
50
+ "model_max_length": 1000000000000000019884624838656,
51
+ "pad_token": "<|plamo:pad|>",
52
+ "sep_token": null,
53
+ "tokenizer_class": "PlamoTokenizer",
54
+ "unk_token": "<|plamo:unk|>"
55
+ }
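Finally, a small hedged sketch (assuming a local copy of the file above at a hypothetical path) that inspects the special-token wiring this config defines: added_tokens_decoder pins the four special tokens to ids 0-3, and add_bos_token is true while add_eos_token is false.

import json

with open("tokenizer_config.json", encoding="utf-8") as f:  # hypothetical local path
    config = json.load(f)

special = {int(i): d["content"] for i, d in config["added_tokens_decoder"].items()}
assert special == {0: "<|plamo:unk|>", 1: "<|plamo:bos|>", 2: "<|plamo:eos|>", 3: "<|plamo:pad|>"}
assert config["add_bos_token"] is True and config["add_eos_token"] is False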