Commit
·
65176a6
0
Parent(s):
initial commit
Browse files- .gitattributes +35 -0
- LICENSE +202 -0
- README.md +111 -0
- config.json +49 -0
- generation_config.json +6 -0
- model-00001-of-00007.safetensors +3 -0
- model-00002-of-00007.safetensors +3 -0
- model-00003-of-00007.safetensors +3 -0
- model-00004-of-00007.safetensors +3 -0
- model-00005-of-00007.safetensors +3 -0
- model-00006-of-00007.safetensors +3 -0
- model-00007-of-00007.safetensors +3 -0
- model.safetensors.index.json +223 -0
- modeling_plamo.py +1616 -0
- special_tokens_map.json +30 -0
- tokenization_plamo.py +392 -0
- tokenizer.jsonl +0 -0
- tokenizer_config.json +55 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
Apache License
|
3 |
+
Version 2.0, January 2004
|
4 |
+
http://www.apache.org/licenses/
|
5 |
+
|
6 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
7 |
+
|
8 |
+
1. Definitions.
|
9 |
+
|
10 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
11 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
12 |
+
|
13 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
14 |
+
the copyright owner that is granting the License.
|
15 |
+
|
16 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
17 |
+
other entities that control, are controlled by, or are under common
|
18 |
+
control with that entity. For the purposes of this definition,
|
19 |
+
"control" means (i) the power, direct or indirect, to cause the
|
20 |
+
direction or management of such entity, whether by contract or
|
21 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
22 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
23 |
+
|
24 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
25 |
+
exercising permissions granted by this License.
|
26 |
+
|
27 |
+
"Source" form shall mean the preferred form for making modifications,
|
28 |
+
including but not limited to software source code, documentation
|
29 |
+
source, and configuration files.
|
30 |
+
|
31 |
+
"Object" form shall mean any form resulting from mechanical
|
32 |
+
transformation or translation of a Source form, including but
|
33 |
+
not limited to compiled object code, generated documentation,
|
34 |
+
and conversions to other media types.
|
35 |
+
|
36 |
+
"Work" shall mean the work of authorship, whether in Source or
|
37 |
+
Object form, made available under the License, as indicated by a
|
38 |
+
copyright notice that is included in or attached to the work
|
39 |
+
(an example is provided in the Appendix below).
|
40 |
+
|
41 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
42 |
+
form, that is based on (or derived from) the Work and for which the
|
43 |
+
editorial revisions, annotations, elaborations, or other modifications
|
44 |
+
represent, as a whole, an original work of authorship. For the purposes
|
45 |
+
of this License, Derivative Works shall not include works that remain
|
46 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
47 |
+
the Work and Derivative Works thereof.
|
48 |
+
|
49 |
+
"Contribution" shall mean any work of authorship, including
|
50 |
+
the original version of the Work and any modifications or additions
|
51 |
+
to that Work or Derivative Works thereof, that is intentionally
|
52 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
53 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
54 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
55 |
+
means any form of electronic, verbal, or written communication sent
|
56 |
+
to the Licensor or its representatives, including but not limited to
|
57 |
+
communication on electronic mailing lists, source code control systems,
|
58 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
59 |
+
Licensor for the purpose of discussing and improving the Work, but
|
60 |
+
excluding communication that is conspicuously marked or otherwise
|
61 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
62 |
+
|
63 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
64 |
+
on behalf of whom a Contribution has been received by Licensor and
|
65 |
+
subsequently incorporated within the Work.
|
66 |
+
|
67 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
68 |
+
this License, each Contributor hereby grants to You a perpetual,
|
69 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
70 |
+
copyright license to reproduce, prepare Derivative Works of,
|
71 |
+
publicly display, publicly perform, sublicense, and distribute the
|
72 |
+
Work and such Derivative Works in Source or Object form.
|
73 |
+
|
74 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
75 |
+
this License, each Contributor hereby grants to You a perpetual,
|
76 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
77 |
+
(except as stated in this section) patent license to make, have made,
|
78 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
79 |
+
where such license applies only to those patent claims licensable
|
80 |
+
by such Contributor that are necessarily infringed by their
|
81 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
82 |
+
with the Work to which such Contribution(s) was submitted. If You
|
83 |
+
institute patent litigation against any entity (including a
|
84 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
85 |
+
or a Contribution incorporated within the Work constitutes direct
|
86 |
+
or contributory patent infringement, then any patent licenses
|
87 |
+
granted to You under this License for that Work shall terminate
|
88 |
+
as of the date such litigation is filed.
|
89 |
+
|
90 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
91 |
+
Work or Derivative Works thereof in any medium, with or without
|
92 |
+
modifications, and in Source or Object form, provided that You
|
93 |
+
meet the following conditions:
|
94 |
+
|
95 |
+
(a) You must give any other recipients of the Work or
|
96 |
+
Derivative Works a copy of this License; and
|
97 |
+
|
98 |
+
(b) You must cause any modified files to carry prominent notices
|
99 |
+
stating that You changed the files; and
|
100 |
+
|
101 |
+
(c) You must retain, in the Source form of any Derivative Works
|
102 |
+
that You distribute, all copyright, patent, trademark, and
|
103 |
+
attribution notices from the Source form of the Work,
|
104 |
+
excluding those notices that do not pertain to any part of
|
105 |
+
the Derivative Works; and
|
106 |
+
|
107 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
108 |
+
distribution, then any Derivative Works that You distribute must
|
109 |
+
include a readable copy of the attribution notices contained
|
110 |
+
within such NOTICE file, excluding those notices that do not
|
111 |
+
pertain to any part of the Derivative Works, in at least one
|
112 |
+
of the following places: within a NOTICE text file distributed
|
113 |
+
as part of the Derivative Works; within the Source form or
|
114 |
+
documentation, if provided along with the Derivative Works; or,
|
115 |
+
within a display generated by the Derivative Works, if and
|
116 |
+
wherever such third-party notices normally appear. The contents
|
117 |
+
of the NOTICE file are for informational purposes only and
|
118 |
+
do not modify the License. You may add Your own attribution
|
119 |
+
notices within Derivative Works that You distribute, alongside
|
120 |
+
or as an addendum to the NOTICE text from the Work, provided
|
121 |
+
that such additional attribution notices cannot be construed
|
122 |
+
as modifying the License.
|
123 |
+
|
124 |
+
You may add Your own copyright statement to Your modifications and
|
125 |
+
may provide additional or different license terms and conditions
|
126 |
+
for use, reproduction, or distribution of Your modifications, or
|
127 |
+
for any such Derivative Works as a whole, provided Your use,
|
128 |
+
reproduction, and distribution of the Work otherwise complies with
|
129 |
+
the conditions stated in this License.
|
130 |
+
|
131 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
132 |
+
any Contribution intentionally submitted for inclusion in the Work
|
133 |
+
by You to the Licensor shall be under the terms and conditions of
|
134 |
+
this License, without any additional terms or conditions.
|
135 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
136 |
+
the terms of any separate license agreement you may have executed
|
137 |
+
with Licensor regarding such Contributions.
|
138 |
+
|
139 |
+
6. Trademarks. This License does not grant permission to use the trade
|
140 |
+
names, trademarks, service marks, or product names of the Licensor,
|
141 |
+
except as required for reasonable and customary use in describing the
|
142 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
143 |
+
|
144 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
145 |
+
agreed to in writing, Licensor provides the Work (and each
|
146 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
147 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
148 |
+
implied, including, without limitation, any warranties or conditions
|
149 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
150 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
151 |
+
appropriateness of using or redistributing the Work and assume any
|
152 |
+
risks associated with Your exercise of permissions under this License.
|
153 |
+
|
154 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
155 |
+
whether in tort (including negligence), contract, or otherwise,
|
156 |
+
unless required by applicable law (such as deliberate and grossly
|
157 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
158 |
+
liable to You for damages, including any direct, indirect, special,
|
159 |
+
incidental, or consequential damages of any character arising as a
|
160 |
+
result of this License or out of the use or inability to use the
|
161 |
+
Work (including but not limited to damages for loss of goodwill,
|
162 |
+
work stoppage, computer failure or malfunction, or any and all
|
163 |
+
other commercial damages or losses), even if such Contributor
|
164 |
+
has been advised of the possibility of such damages.
|
165 |
+
|
166 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
167 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
168 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
169 |
+
or other liability obligations and/or rights consistent with this
|
170 |
+
License. However, in accepting such obligations, You may act only
|
171 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
172 |
+
of any other Contributor, and only if You agree to indemnify,
|
173 |
+
defend, and hold each Contributor harmless for any liability
|
174 |
+
incurred by, or claims asserted against, such Contributor by reason
|
175 |
+
of your accepting any such warranty or additional liability.
|
176 |
+
|
177 |
+
END OF TERMS AND CONDITIONS
|
178 |
+
|
179 |
+
APPENDIX: How to apply the Apache License to your work.
|
180 |
+
|
181 |
+
To apply the Apache License to your work, attach the following
|
182 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
183 |
+
replaced with your own identifying information. (Don't include
|
184 |
+
the brackets!) The text should be enclosed in the appropriate
|
185 |
+
comment syntax for the file format. We also recommend that a
|
186 |
+
file or class name and description of purpose be included on the
|
187 |
+
same "printed page" as the copyright notice for easier
|
188 |
+
identification within third-party archives.
|
189 |
+
|
190 |
+
Copyright 2025 Preferred Elements, Inc.
|
191 |
+
|
192 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
193 |
+
you may not use this file except in compliance with the License.
|
194 |
+
You may obtain a copy of the License at
|
195 |
+
|
196 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
197 |
+
|
198 |
+
Unless required by applicable law or agreed to in writing, software
|
199 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
200 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
201 |
+
See the License for the specific language governing permissions and
|
202 |
+
limitations under the License.
|
README.md
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
language:
|
4 |
+
- en
|
5 |
+
- ja
|
6 |
+
pipeline_tag: text-generation
|
7 |
+
library_name: transformers
|
8 |
+
---
|
9 |
+
|
10 |
+
# PLaMo 2 1B
|
11 |
+
|
12 |
+
|
13 |
+
## Model Description
|
14 |
+
PLaMo 2 1B is a 1B model pre-trained on English and Japanese datasets, developed by Preferred Elements, Inc.
|
15 |
+
|
16 |
+
PLaMo 2 1B is released under Apache License version 2.0.
|
17 |
+
|
18 |
+
**NOTE**: This model has **NOT** been instruction-tuned for chat dialog or other downstream tasks.
|
19 |
+
|
20 |
+
|
21 |
+
## Usage
|
22 |
+
|
23 |
+
### Requirements
|
24 |
+
|
25 |
+
```
|
26 |
+
numpy>=1.26.4
|
27 |
+
numba>=0.60.0
|
28 |
+
torch>=2.4.1
|
29 |
+
transformers>=4.44.2
|
30 |
+
mamba_ssm>=2.2.2
|
31 |
+
causal_conv1d>=1.4.0
|
32 |
+
```
|
33 |
+
|
34 |
+
### Use a pipeline as a high-level helper
|
35 |
+
|
36 |
+
```python
|
37 |
+
import transformers
|
38 |
+
pipeline = transformers.pipeline("text-generation", model="pfnet/plamo-2-1b", trust_remote_code=True)
|
39 |
+
print(pipeline("The future of artificial intelligence technology is ", max_new_tokens=32))
|
40 |
+
```
|
41 |
+
|
42 |
+
### Load model directly
|
43 |
+
|
44 |
+
```python
|
45 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
46 |
+
tokenizer = AutoTokenizer.from_pretrained("pfnet/plamo-2-1b", trust_remote_code=True)
|
47 |
+
model = AutoModelForCausalLM.from_pretrained("pfnet/plamo-2-1b", trust_remote_code=True)
|
48 |
+
text = "これからの人工知能技術は"
|
49 |
+
input_ids = tokenizer(text, return_tensors="pt").input_ids
|
50 |
+
generated_tokens = model.generate(
|
51 |
+
inputs=input_ids,
|
52 |
+
max_new_tokens=32,
|
53 |
+
do_sample=True,
|
54 |
+
top_k=50,
|
55 |
+
top_p=0.95,
|
56 |
+
temperature=1.0,
|
57 |
+
)[0]
|
58 |
+
generated_text = tokenizer.decode(generated_tokens)
|
59 |
+
print(generated_text)
|
60 |
+
```
|
61 |
+
|
62 |
+
|
63 |
+
## Model Details
|
64 |
+
|
65 |
+
- Model size: 1B
|
66 |
+
- Trained tokens: 4T tokens
|
67 |
+
- Developed by: Preferred Elements, Inc.
|
68 |
+
- Model type: Causal decoder-only
|
69 |
+
- Language(s): English, Japanese
|
70 |
+
- License: Apache License version 2.0
|
71 |
+
|
72 |
+
|
73 |
+
## Training Dataset
|
74 |
+
|
75 |
+
We trained PLaMo 2 1B in two phases, phase 1 with 3.5T tokens and phase 2 with 0.5T tokens.
|
76 |
+
The percentage of datasets in each phase is shown in the following table.
|
77 |
+
|
78 |
+
||3.5T (phase 1)|0.5T (phase 2)|Tokens|
|
79 |
+
|---|:---:|:---:|:---:|
|
80 |
+
|English|45 %|35 %|1.75 T|
|
81 |
+
|Japanese|30 %|40 %|1.25 T|
|
82 |
+
|Coding|15 %|15 %|0.6 T|
|
83 |
+
|Other|10 %|10 %|0.4 T|
|
84 |
+
|
85 |
+
|
86 |
+
## Tokenizer
|
87 |
+
|
88 |
+
PLaMo 2 1B tokenizer is optimized by numba, which is JIT compiler for numerical functions.
|
89 |
+
The tokenizer is trained on a subset of the datasets for model pre-training.
|
90 |
+
|
91 |
+
|
92 |
+
## Tech Blog
|
93 |
+
|
94 |
+
- (JA) https://tech.preferred.jp/ja/blog/plamo-2/
|
95 |
+
- (JA) https://tech.preferred.jp/ja/blog/plamo-2-tokenizer/
|
96 |
+
|
97 |
+
|
98 |
+
## Bias, Risks, and Limitations
|
99 |
+
|
100 |
+
PLaMo 2 1B is a new technology that carries risks with use. Testing conducted to date has been in English and Japanese, and has not covered, nor could it cover all scenarios. For these reasons, as with all LLMs, PLaMo 2 1B’s potential outputs cannot be predicted in advance, and the model may in some instances produce inaccurate, biased or other objectionable responses to user prompts. Therefore, before deploying any applications of PLaMo 2 1B, developers should perform safety testing and tuning tailored to their specific applications of the model.
|
101 |
+
|
102 |
+
|
103 |
+
## Acknowledgement
|
104 |
+
|
105 |
+
This model is trained under the project, “Research and Development Project of the Enhanced Infrastructures for Post 5G Information and Communication System” (JPNP 20017), subsidized by the New Energy and Industrial Technology Development Organization (NEDO).
|
106 |
+
|
107 |
+
|
108 |
+
## AI policies for Preferred Networks, Inc. group
|
109 |
+
|
110 |
+
- (EN) https://www.preferred.jp/en/company/aipolicy/
|
111 |
+
- (JA) https://www.preferred.jp/ja/company/aipolicy/
|
config.json
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"PlamoForCausalLM"
|
4 |
+
],
|
5 |
+
"attention_window_size": 2048,
|
6 |
+
"auto_map": {
|
7 |
+
"AutoConfig": "modeling_plamo.PlamoConfig",
|
8 |
+
"AutoModelForCausalLM": "modeling_plamo.PlamoForCausalLM"
|
9 |
+
},
|
10 |
+
"bos_token_id": 1,
|
11 |
+
"capacity_factor": 1.0,
|
12 |
+
"eos_token_id": 2,
|
13 |
+
"eval_attention_n_bit": null,
|
14 |
+
"eval_mlp_n_bit": null,
|
15 |
+
"expert_dropout": 0.0,
|
16 |
+
"fp8_accum_dtype": "bfloat16",
|
17 |
+
"group_size": 1024,
|
18 |
+
"hidden_size": 2048,
|
19 |
+
"hidden_size_per_head": 128,
|
20 |
+
"image_feature_size": null,
|
21 |
+
"image_proj_type": "linear",
|
22 |
+
"image_token_id": null,
|
23 |
+
"intermediate_size": 8192,
|
24 |
+
"k_expert": null,
|
25 |
+
"linear_type": "fp8",
|
26 |
+
"mamba_chunk_size": 256,
|
27 |
+
"mamba_d_conv": 4,
|
28 |
+
"mamba_d_state": 64,
|
29 |
+
"mamba_enabled": true,
|
30 |
+
"mamba_num_heads": 32,
|
31 |
+
"mamba_step": 2,
|
32 |
+
"max_position_embeddings": 10485760,
|
33 |
+
"model_type": "plamo",
|
34 |
+
"n_expert": null,
|
35 |
+
"num_attention_heads": 16,
|
36 |
+
"num_hidden_layers": 16,
|
37 |
+
"num_key_value_heads": 1,
|
38 |
+
"rms_norm_eps": 1e-06,
|
39 |
+
"shared_intermediate_size": null,
|
40 |
+
"sliding_window": 2048,
|
41 |
+
"sparse_intermediate_size": null,
|
42 |
+
"sparse_step": null,
|
43 |
+
"tokenizer_class": "PlamoTokenizer",
|
44 |
+
"torch_dtype": "float32",
|
45 |
+
"transformers_version": "4.44.2",
|
46 |
+
"use_cache": true,
|
47 |
+
"use_predefined_initial_state": false,
|
48 |
+
"vocab_size": 100000
|
49 |
+
}
|
generation_config.json
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"bos_token_id": 1,
|
4 |
+
"eos_token_id": 2,
|
5 |
+
"transformers_version": "4.44.2"
|
6 |
+
}
|
model-00001-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:306bfd09ff22b5c044abfbc310b0a916a85a450ff52beaa90e245bb67b4cb1e9
|
3 |
+
size 4060618320
|
model-00002-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:759d54a62b8f8124031bf4048b2b7e9b9fe3439e75ce772e48562369768836a8
|
3 |
+
size 285214592
|
model-00003-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c7172210250c2b4030b1bc9c75a68224844e770838159e939bb6c68593d2e235
|
3 |
+
size 819200136
|
model-00004-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:017d164149a8e31acdde2f8c6dd45568d7a8900951f83cb5fbcebac64c662695
|
3 |
+
size 621448
|
model-00005-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6bfe7b7d45c40c75d7ff7b362fd88bf34e23aedf6a637df4eafc9f7d6231f8ca
|
3 |
+
size 131976
|
model-00006-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3dadc185925d0be7f495c51924546b0f07f93550cf4222da96dedd335beb3b8f
|
3 |
+
size 5296
|
model-00007-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:60d95b10b6e140a9626a7058d5038528f2ff80148dc4569b881db56052046509
|
3 |
+
size 40
|
model.safetensors.index.json
ADDED
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"weight_map": {
|
3 |
+
"model.layers.layers.0.mixer.in_proj.weight": "model-00001-of-00007.safetensors",
|
4 |
+
"model.layers.layers.0.mixer.conv1d.weight": "model-00001-of-00007.safetensors",
|
5 |
+
"model.layers.layers.0.mixer.bcdt_proj.weight": "model-00001-of-00007.safetensors",
|
6 |
+
"model.layers.layers.0.mixer.out_proj.weight": "model-00001-of-00007.safetensors",
|
7 |
+
"model.layers.layers.2.mixer.in_proj.weight": "model-00001-of-00007.safetensors",
|
8 |
+
"model.layers.layers.2.mixer.conv1d.weight": "model-00001-of-00007.safetensors",
|
9 |
+
"model.layers.layers.2.mixer.bcdt_proj.weight": "model-00001-of-00007.safetensors",
|
10 |
+
"model.layers.layers.2.mixer.out_proj.weight": "model-00001-of-00007.safetensors",
|
11 |
+
"model.layers.layers.4.mixer.in_proj.weight": "model-00001-of-00007.safetensors",
|
12 |
+
"model.layers.layers.4.mixer.conv1d.weight": "model-00001-of-00007.safetensors",
|
13 |
+
"model.layers.layers.4.mixer.bcdt_proj.weight": "model-00001-of-00007.safetensors",
|
14 |
+
"model.layers.layers.4.mixer.out_proj.weight": "model-00001-of-00007.safetensors",
|
15 |
+
"model.layers.layers.6.mixer.in_proj.weight": "model-00001-of-00007.safetensors",
|
16 |
+
"model.layers.layers.6.mixer.conv1d.weight": "model-00001-of-00007.safetensors",
|
17 |
+
"model.layers.layers.6.mixer.bcdt_proj.weight": "model-00001-of-00007.safetensors",
|
18 |
+
"model.layers.layers.6.mixer.out_proj.weight": "model-00001-of-00007.safetensors",
|
19 |
+
"model.layers.layers.8.mixer.in_proj.weight": "model-00001-of-00007.safetensors",
|
20 |
+
"model.layers.layers.8.mixer.conv1d.weight": "model-00001-of-00007.safetensors",
|
21 |
+
"model.layers.layers.8.mixer.bcdt_proj.weight": "model-00001-of-00007.safetensors",
|
22 |
+
"model.layers.layers.8.mixer.out_proj.weight": "model-00001-of-00007.safetensors",
|
23 |
+
"model.layers.layers.10.mixer.in_proj.weight": "model-00001-of-00007.safetensors",
|
24 |
+
"model.layers.layers.10.mixer.conv1d.weight": "model-00001-of-00007.safetensors",
|
25 |
+
"model.layers.layers.10.mixer.bcdt_proj.weight": "model-00001-of-00007.safetensors",
|
26 |
+
"model.layers.layers.10.mixer.out_proj.weight": "model-00001-of-00007.safetensors",
|
27 |
+
"model.layers.layers.12.mixer.in_proj.weight": "model-00001-of-00007.safetensors",
|
28 |
+
"model.layers.layers.12.mixer.conv1d.weight": "model-00001-of-00007.safetensors",
|
29 |
+
"model.layers.layers.12.mixer.bcdt_proj.weight": "model-00001-of-00007.safetensors",
|
30 |
+
"model.layers.layers.12.mixer.out_proj.weight": "model-00001-of-00007.safetensors",
|
31 |
+
"model.layers.layers.14.mixer.in_proj.weight": "model-00001-of-00007.safetensors",
|
32 |
+
"model.layers.layers.14.mixer.conv1d.weight": "model-00001-of-00007.safetensors",
|
33 |
+
"model.layers.layers.14.mixer.bcdt_proj.weight": "model-00001-of-00007.safetensors",
|
34 |
+
"model.layers.layers.14.mixer.out_proj.weight": "model-00001-of-00007.safetensors",
|
35 |
+
"model.layers.layers.1.mixer.qkv_proj.weight": "model-00002-of-00007.safetensors",
|
36 |
+
"model.layers.layers.1.mixer.o_proj.weight": "model-00002-of-00007.safetensors",
|
37 |
+
"model.layers.layers.3.mixer.qkv_proj.weight": "model-00002-of-00007.safetensors",
|
38 |
+
"model.layers.layers.3.mixer.o_proj.weight": "model-00002-of-00007.safetensors",
|
39 |
+
"model.layers.layers.5.mixer.qkv_proj.weight": "model-00002-of-00007.safetensors",
|
40 |
+
"model.layers.layers.5.mixer.o_proj.weight": "model-00002-of-00007.safetensors",
|
41 |
+
"model.layers.layers.7.mixer.qkv_proj.weight": "model-00002-of-00007.safetensors",
|
42 |
+
"model.layers.layers.7.mixer.o_proj.weight": "model-00002-of-00007.safetensors",
|
43 |
+
"model.layers.layers.9.mixer.qkv_proj.weight": "model-00002-of-00007.safetensors",
|
44 |
+
"model.layers.layers.9.mixer.o_proj.weight": "model-00002-of-00007.safetensors",
|
45 |
+
"model.layers.layers.11.mixer.qkv_proj.weight": "model-00002-of-00007.safetensors",
|
46 |
+
"model.layers.layers.11.mixer.o_proj.weight": "model-00002-of-00007.safetensors",
|
47 |
+
"model.layers.layers.13.mixer.qkv_proj.weight": "model-00002-of-00007.safetensors",
|
48 |
+
"model.layers.layers.13.mixer.o_proj.weight": "model-00002-of-00007.safetensors",
|
49 |
+
"model.layers.layers.15.mixer.qkv_proj.weight": "model-00002-of-00007.safetensors",
|
50 |
+
"model.layers.layers.15.mixer.o_proj.weight": "model-00002-of-00007.safetensors",
|
51 |
+
"model.embed_tokens.weight": "model-00003-of-00007.safetensors",
|
52 |
+
"model.layers.layers.0.mixer.dt_norm_weight": "model-00004-of-00007.safetensors",
|
53 |
+
"model.layers.layers.0.mixer.B_norm_weight": "model-00004-of-00007.safetensors",
|
54 |
+
"model.layers.layers.0.mixer.C_norm_weight": "model-00004-of-00007.safetensors",
|
55 |
+
"model.layers.layers.0.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
56 |
+
"model.layers.layers.0.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
57 |
+
"model.layers.layers.0.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
58 |
+
"model.layers.layers.0.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
59 |
+
"model.layers.layers.1.mixer.q_weight": "model-00004-of-00007.safetensors",
|
60 |
+
"model.layers.layers.1.mixer.k_weight": "model-00004-of-00007.safetensors",
|
61 |
+
"model.layers.layers.1.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
62 |
+
"model.layers.layers.1.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
63 |
+
"model.layers.layers.1.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
64 |
+
"model.layers.layers.1.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
65 |
+
"model.layers.layers.2.mixer.dt_norm_weight": "model-00004-of-00007.safetensors",
|
66 |
+
"model.layers.layers.2.mixer.B_norm_weight": "model-00004-of-00007.safetensors",
|
67 |
+
"model.layers.layers.2.mixer.C_norm_weight": "model-00004-of-00007.safetensors",
|
68 |
+
"model.layers.layers.2.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
69 |
+
"model.layers.layers.2.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
70 |
+
"model.layers.layers.2.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
71 |
+
"model.layers.layers.2.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
72 |
+
"model.layers.layers.3.mixer.q_weight": "model-00004-of-00007.safetensors",
|
73 |
+
"model.layers.layers.3.mixer.k_weight": "model-00004-of-00007.safetensors",
|
74 |
+
"model.layers.layers.3.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
75 |
+
"model.layers.layers.3.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
76 |
+
"model.layers.layers.3.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
77 |
+
"model.layers.layers.3.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
78 |
+
"model.layers.layers.4.mixer.dt_norm_weight": "model-00004-of-00007.safetensors",
|
79 |
+
"model.layers.layers.4.mixer.B_norm_weight": "model-00004-of-00007.safetensors",
|
80 |
+
"model.layers.layers.4.mixer.C_norm_weight": "model-00004-of-00007.safetensors",
|
81 |
+
"model.layers.layers.4.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
82 |
+
"model.layers.layers.4.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
83 |
+
"model.layers.layers.4.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
84 |
+
"model.layers.layers.4.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
85 |
+
"model.layers.layers.5.mixer.q_weight": "model-00004-of-00007.safetensors",
|
86 |
+
"model.layers.layers.5.mixer.k_weight": "model-00004-of-00007.safetensors",
|
87 |
+
"model.layers.layers.5.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
88 |
+
"model.layers.layers.5.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
89 |
+
"model.layers.layers.5.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
90 |
+
"model.layers.layers.5.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
91 |
+
"model.layers.layers.6.mixer.dt_norm_weight": "model-00004-of-00007.safetensors",
|
92 |
+
"model.layers.layers.6.mixer.B_norm_weight": "model-00004-of-00007.safetensors",
|
93 |
+
"model.layers.layers.6.mixer.C_norm_weight": "model-00004-of-00007.safetensors",
|
94 |
+
"model.layers.layers.6.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
95 |
+
"model.layers.layers.6.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
96 |
+
"model.layers.layers.6.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
97 |
+
"model.layers.layers.6.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
98 |
+
"model.layers.layers.7.mixer.q_weight": "model-00004-of-00007.safetensors",
|
99 |
+
"model.layers.layers.7.mixer.k_weight": "model-00004-of-00007.safetensors",
|
100 |
+
"model.layers.layers.7.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
101 |
+
"model.layers.layers.7.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
102 |
+
"model.layers.layers.7.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
103 |
+
"model.layers.layers.7.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
104 |
+
"model.layers.layers.8.mixer.dt_norm_weight": "model-00004-of-00007.safetensors",
|
105 |
+
"model.layers.layers.8.mixer.B_norm_weight": "model-00004-of-00007.safetensors",
|
106 |
+
"model.layers.layers.8.mixer.C_norm_weight": "model-00004-of-00007.safetensors",
|
107 |
+
"model.layers.layers.8.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
108 |
+
"model.layers.layers.8.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
109 |
+
"model.layers.layers.8.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
110 |
+
"model.layers.layers.8.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
111 |
+
"model.layers.layers.9.mixer.q_weight": "model-00004-of-00007.safetensors",
|
112 |
+
"model.layers.layers.9.mixer.k_weight": "model-00004-of-00007.safetensors",
|
113 |
+
"model.layers.layers.9.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
114 |
+
"model.layers.layers.9.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
115 |
+
"model.layers.layers.9.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
116 |
+
"model.layers.layers.9.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
117 |
+
"model.layers.layers.10.mixer.dt_norm_weight": "model-00004-of-00007.safetensors",
|
118 |
+
"model.layers.layers.10.mixer.B_norm_weight": "model-00004-of-00007.safetensors",
|
119 |
+
"model.layers.layers.10.mixer.C_norm_weight": "model-00004-of-00007.safetensors",
|
120 |
+
"model.layers.layers.10.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
121 |
+
"model.layers.layers.10.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
122 |
+
"model.layers.layers.10.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
123 |
+
"model.layers.layers.10.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
124 |
+
"model.layers.layers.11.mixer.q_weight": "model-00004-of-00007.safetensors",
|
125 |
+
"model.layers.layers.11.mixer.k_weight": "model-00004-of-00007.safetensors",
|
126 |
+
"model.layers.layers.11.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
127 |
+
"model.layers.layers.11.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
128 |
+
"model.layers.layers.11.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
129 |
+
"model.layers.layers.11.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
130 |
+
"model.layers.layers.12.mixer.dt_norm_weight": "model-00004-of-00007.safetensors",
|
131 |
+
"model.layers.layers.12.mixer.B_norm_weight": "model-00004-of-00007.safetensors",
|
132 |
+
"model.layers.layers.12.mixer.C_norm_weight": "model-00004-of-00007.safetensors",
|
133 |
+
"model.layers.layers.12.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
134 |
+
"model.layers.layers.12.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
135 |
+
"model.layers.layers.12.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
136 |
+
"model.layers.layers.12.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
137 |
+
"model.layers.layers.13.mixer.q_weight": "model-00004-of-00007.safetensors",
|
138 |
+
"model.layers.layers.13.mixer.k_weight": "model-00004-of-00007.safetensors",
|
139 |
+
"model.layers.layers.13.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
140 |
+
"model.layers.layers.13.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
141 |
+
"model.layers.layers.13.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
142 |
+
"model.layers.layers.13.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
143 |
+
"model.layers.layers.14.mixer.dt_norm_weight": "model-00004-of-00007.safetensors",
|
144 |
+
"model.layers.layers.14.mixer.B_norm_weight": "model-00004-of-00007.safetensors",
|
145 |
+
"model.layers.layers.14.mixer.C_norm_weight": "model-00004-of-00007.safetensors",
|
146 |
+
"model.layers.layers.14.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
147 |
+
"model.layers.layers.14.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
148 |
+
"model.layers.layers.14.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
149 |
+
"model.layers.layers.14.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
150 |
+
"model.layers.layers.15.mixer.q_weight": "model-00004-of-00007.safetensors",
|
151 |
+
"model.layers.layers.15.mixer.k_weight": "model-00004-of-00007.safetensors",
|
152 |
+
"model.layers.layers.15.pre_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
153 |
+
"model.layers.layers.15.post_mixer_norm.weight": "model-00004-of-00007.safetensors",
|
154 |
+
"model.layers.layers.15.pre_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
155 |
+
"model.layers.layers.15.post_mlp_norm.weight": "model-00004-of-00007.safetensors",
|
156 |
+
"model.norm.weight": "model-00004-of-00007.safetensors",
|
157 |
+
"model.layers.layers.0.mixer.dt_proj.weight": "model-00005-of-00007.safetensors",
|
158 |
+
"model.layers.layers.2.mixer.dt_proj.weight": "model-00005-of-00007.safetensors",
|
159 |
+
"model.layers.layers.4.mixer.dt_proj.weight": "model-00005-of-00007.safetensors",
|
160 |
+
"model.layers.layers.6.mixer.dt_proj.weight": "model-00005-of-00007.safetensors",
|
161 |
+
"model.layers.layers.8.mixer.dt_proj.weight": "model-00005-of-00007.safetensors",
|
162 |
+
"model.layers.layers.10.mixer.dt_proj.weight": "model-00005-of-00007.safetensors",
|
163 |
+
"model.layers.layers.12.mixer.dt_proj.weight": "model-00005-of-00007.safetensors",
|
164 |
+
"model.layers.layers.14.mixer.dt_proj.weight": "model-00005-of-00007.safetensors",
|
165 |
+
"model.layers.layers.0.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
|
166 |
+
"model.layers.layers.0.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
|
167 |
+
"model.layers.layers.1.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
|
168 |
+
"model.layers.layers.1.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
|
169 |
+
"model.layers.layers.2.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
|
170 |
+
"model.layers.layers.2.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
|
171 |
+
"model.layers.layers.3.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
|
172 |
+
"model.layers.layers.3.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
|
173 |
+
"model.layers.layers.4.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
|
174 |
+
"model.layers.layers.4.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
|
175 |
+
"model.layers.layers.5.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
|
176 |
+
"model.layers.layers.5.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
|
177 |
+
"model.layers.layers.6.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
|
178 |
+
"model.layers.layers.6.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
|
179 |
+
"model.layers.layers.7.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
|
180 |
+
"model.layers.layers.7.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
|
181 |
+
"model.layers.layers.8.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
|
182 |
+
"model.layers.layers.8.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
|
183 |
+
"model.layers.layers.9.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
|
184 |
+
"model.layers.layers.9.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
|
185 |
+
"model.layers.layers.10.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
|
186 |
+
"model.layers.layers.10.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
|
187 |
+
"model.layers.layers.11.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
|
188 |
+
"model.layers.layers.11.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
|
189 |
+
"model.layers.layers.12.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
|
190 |
+
"model.layers.layers.12.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
|
191 |
+
"model.layers.layers.13.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
|
192 |
+
"model.layers.layers.13.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
|
193 |
+
"model.layers.layers.14.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
|
194 |
+
"model.layers.layers.14.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
|
195 |
+
"model.layers.layers.15.mlp.gate_up_proj.weight": "model-00001-of-00007.safetensors",
|
196 |
+
"model.layers.layers.15.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
|
197 |
+
"model.layers.layers.0.mixer.dt_bias": "model-00006-of-00007.safetensors",
|
198 |
+
"model.layers.layers.0.mixer.A_log": "model-00006-of-00007.safetensors",
|
199 |
+
"model.layers.layers.0.mixer.D": "model-00006-of-00007.safetensors",
|
200 |
+
"model.layers.layers.2.mixer.dt_bias": "model-00006-of-00007.safetensors",
|
201 |
+
"model.layers.layers.2.mixer.A_log": "model-00006-of-00007.safetensors",
|
202 |
+
"model.layers.layers.2.mixer.D": "model-00006-of-00007.safetensors",
|
203 |
+
"model.layers.layers.4.mixer.dt_bias": "model-00006-of-00007.safetensors",
|
204 |
+
"model.layers.layers.4.mixer.A_log": "model-00006-of-00007.safetensors",
|
205 |
+
"model.layers.layers.4.mixer.D": "model-00006-of-00007.safetensors",
|
206 |
+
"model.layers.layers.6.mixer.dt_bias": "model-00006-of-00007.safetensors",
|
207 |
+
"model.layers.layers.6.mixer.A_log": "model-00006-of-00007.safetensors",
|
208 |
+
"model.layers.layers.6.mixer.D": "model-00006-of-00007.safetensors",
|
209 |
+
"model.layers.layers.8.mixer.dt_bias": "model-00006-of-00007.safetensors",
|
210 |
+
"model.layers.layers.8.mixer.A_log": "model-00006-of-00007.safetensors",
|
211 |
+
"model.layers.layers.8.mixer.D": "model-00006-of-00007.safetensors",
|
212 |
+
"model.layers.layers.10.mixer.dt_bias": "model-00006-of-00007.safetensors",
|
213 |
+
"model.layers.layers.10.mixer.A_log": "model-00006-of-00007.safetensors",
|
214 |
+
"model.layers.layers.10.mixer.D": "model-00006-of-00007.safetensors",
|
215 |
+
"model.layers.layers.12.mixer.dt_bias": "model-00006-of-00007.safetensors",
|
216 |
+
"model.layers.layers.12.mixer.A_log": "model-00006-of-00007.safetensors",
|
217 |
+
"model.layers.layers.12.mixer.D": "model-00006-of-00007.safetensors",
|
218 |
+
"model.layers.layers.14.mixer.dt_bias": "model-00006-of-00007.safetensors",
|
219 |
+
"model.layers.layers.14.mixer.A_log": "model-00006-of-00007.safetensors",
|
220 |
+
"model.layers.layers.14.mixer.D": "model-00006-of-00007.safetensors"
|
221 |
+
},
|
222 |
+
"metadata": {}
|
223 |
+
}
|
modeling_plamo.py
ADDED
@@ -0,0 +1,1616 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import enum
|
2 |
+
import math
|
3 |
+
import warnings
|
4 |
+
from typing import Any, Dict, List, Literal, NamedTuple, Optional, Tuple, Union
|
5 |
+
|
6 |
+
try:
|
7 |
+
# It is difficult to install mamba_ssm in login node because
|
8 |
+
# it requires GPU for installation
|
9 |
+
import mamba_ssm
|
10 |
+
except ModuleNotFoundError:
|
11 |
+
warnings.warn("mamba_ssm could not be imported", stacklevel=2)
|
12 |
+
try:
|
13 |
+
# It is difficult to install causal_conv1d in login node because
|
14 |
+
# it requires GPU for installation
|
15 |
+
import causal_conv1d.causal_conv1d_interface as causal_conv1d
|
16 |
+
except ModuleNotFoundError:
|
17 |
+
warnings.warn("causal_conv1d could not be imported", stacklevel=2)
|
18 |
+
import torch
|
19 |
+
from torch import nn
|
20 |
+
from torch.nn import functional as F
|
21 |
+
from transformers import PretrainedConfig, PreTrainedModel
|
22 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
23 |
+
|
24 |
+
|
25 |
+
def _is_first_token(mask: torch.Tensor) -> torch.Tensor:
|
26 |
+
assert mask.dtype == torch.bool
|
27 |
+
B, Nh, q_len, kv_len = mask.shape
|
28 |
+
mask = mask[:, :, :, -q_len:]
|
29 |
+
cont = q_len != kv_len
|
30 |
+
v = False if cont else True
|
31 |
+
out = torch.logical_not(torch.diagonal(mask, offset=-1, dim1=-2, dim2=-1).bool())
|
32 |
+
out = torch.cat(
|
33 |
+
[
|
34 |
+
torch.full(size=(B, Nh, 1), dtype=torch.bool, device=out.device, fill_value=v),
|
35 |
+
out,
|
36 |
+
],
|
37 |
+
dim=-1,
|
38 |
+
)
|
39 |
+
return out
|
40 |
+
|
41 |
+
|
42 |
+
def _swiglu(h: torch.Tensor) -> torch.Tensor:
|
43 |
+
h0, h1 = h.chunk(2, dim=-1)
|
44 |
+
return torch.nn.functional.silu(h0) * h1
|
45 |
+
|
46 |
+
|
47 |
+
class RotaryEmbedding(torch.nn.Module):
|
48 |
+
def __init__(
|
49 |
+
self, dim: int, max_position_embeddings: int = 2048, base: int = 10000, device: Optional[torch.device] = None
|
50 |
+
) -> None:
|
51 |
+
super().__init__()
|
52 |
+
|
53 |
+
self.dim = dim
|
54 |
+
self.max_position_embeddings = max_position_embeddings
|
55 |
+
self.base = base
|
56 |
+
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
57 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
58 |
+
|
59 |
+
# Build here to make `torch.jit.trace` work.
|
60 |
+
self._set_cos_sin_cache(
|
61 |
+
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
62 |
+
)
|
63 |
+
|
64 |
+
def _set_cos_sin_cache(self, seq_len: int, device: Any, dtype: Any) -> None:
|
65 |
+
self.max_seq_len_cached = seq_len
|
66 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) # type: ignore
|
67 |
+
|
68 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
69 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
70 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
71 |
+
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
|
72 |
+
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
|
73 |
+
|
74 |
+
def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
75 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
76 |
+
if seq_len > self.max_seq_len_cached:
|
77 |
+
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
78 |
+
|
79 |
+
return (
|
80 |
+
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), # type: ignore
|
81 |
+
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), # type: ignore
|
82 |
+
)
|
83 |
+
|
84 |
+
|
85 |
+
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
|
86 |
+
"""Rotates half the hidden dims of the input."""
|
87 |
+
x1 = x[..., : x.shape[-1] // 2]
|
88 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
89 |
+
return torch.cat((-x2, x1), dim=-1)
|
90 |
+
|
91 |
+
|
92 |
+
def _rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
|
93 |
+
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
94 |
+
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
95 |
+
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
96 |
+
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
97 |
+
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
98 |
+
x_embed = (x * cos) + (_rotate_half(x) * sin)
|
99 |
+
return x_embed
|
100 |
+
|
101 |
+
|
102 |
+
class LinearType(str, enum.Enum):
|
103 |
+
Normal = "normal"
|
104 |
+
Fp8 = "fp8"
|
105 |
+
Fp8Retain = "fp8-retain"
|
106 |
+
|
107 |
+
|
108 |
+
class PlamoConfig(PretrainedConfig): # type: ignore
|
109 |
+
model_type: str = "plamo"
|
110 |
+
|
111 |
+
def __init__(
|
112 |
+
self,
|
113 |
+
hidden_size: int = 4096,
|
114 |
+
num_hidden_layers: int = 32,
|
115 |
+
rms_norm_eps: float = 1e-6,
|
116 |
+
tie_word_embeddings: bool = True,
|
117 |
+
# Attention
|
118 |
+
num_attention_heads: int = 32,
|
119 |
+
num_key_value_heads: int = 4,
|
120 |
+
hidden_size_per_head: int = 128,
|
121 |
+
max_position_embeddings: int = 2048,
|
122 |
+
attention_window_size: int = 2048,
|
123 |
+
full_attention_idx: list[int] | None = None,
|
124 |
+
# Mamba
|
125 |
+
mamba_d_state: int = 64,
|
126 |
+
mamba_d_conv: int = 4,
|
127 |
+
mamba_num_heads: int = 64,
|
128 |
+
mamba_step: int = 2,
|
129 |
+
mamba_chunk_size: int = 256,
|
130 |
+
mamba_enabled: bool = True,
|
131 |
+
# MLP
|
132 |
+
intermediate_size: int = 13312,
|
133 |
+
# Tokenizer
|
134 |
+
vocab_size: int = 32000,
|
135 |
+
tokenizer_class: str = "PlamoTokenizer",
|
136 |
+
pad_token_id: Optional[int] = None,
|
137 |
+
bos_token_id: int = 1,
|
138 |
+
eos_token_id: int = 2,
|
139 |
+
# Multimodal
|
140 |
+
image_token_id: Optional[int] = None,
|
141 |
+
image_feature_size: Optional[int] = None,
|
142 |
+
image_proj_type: Literal["linear", "mlp"] = "linear",
|
143 |
+
# FP8
|
144 |
+
linear_type: LinearType = LinearType.Normal,
|
145 |
+
fp8_accum_dtype: Optional[str] = None,
|
146 |
+
# Evaluation
|
147 |
+
eval_attention_n_bit: Optional[int] = None,
|
148 |
+
eval_mlp_n_bit: Optional[int] = None,
|
149 |
+
use_cache: bool = True,
|
150 |
+
**kwargs: Any,
|
151 |
+
) -> None:
|
152 |
+
# max_position_embeddings is often used to determine the max length during inference,
|
153 |
+
# but samba should have extrapolation abilities
|
154 |
+
self.max_position_embeddings = max(10 * 1024 * 1024, max_position_embeddings)
|
155 |
+
self.hidden_size = hidden_size
|
156 |
+
self.rms_norm_eps = rms_norm_eps
|
157 |
+
|
158 |
+
self.num_hidden_layers = num_hidden_layers
|
159 |
+
self.num_attention_heads = num_attention_heads
|
160 |
+
self.hidden_size_per_head = hidden_size_per_head
|
161 |
+
self.num_key_value_heads = num_key_value_heads
|
162 |
+
self.attention_window_size = attention_window_size
|
163 |
+
self.full_attention_idx = full_attention_idx if full_attention_idx is not None else []
|
164 |
+
|
165 |
+
self.mamba_d_state = mamba_d_state
|
166 |
+
self.mamba_d_conv = mamba_d_conv
|
167 |
+
self.mamba_num_heads = mamba_num_heads
|
168 |
+
self.mamba_step = mamba_step
|
169 |
+
self.mamba_chunk_size = mamba_chunk_size
|
170 |
+
self.mamba_enabled = mamba_enabled
|
171 |
+
|
172 |
+
self.intermediate_size = intermediate_size
|
173 |
+
|
174 |
+
self.vocab_size = vocab_size
|
175 |
+
|
176 |
+
self.image_token_id = image_token_id
|
177 |
+
self.image_feature_size = image_feature_size
|
178 |
+
self.image_proj_type = image_proj_type
|
179 |
+
|
180 |
+
self.linear_type = linear_type
|
181 |
+
self.fp8_accum_dtype = fp8_accum_dtype
|
182 |
+
|
183 |
+
self.eval_attention_n_bit = eval_attention_n_bit
|
184 |
+
self.eval_mlp_n_bit = eval_mlp_n_bit
|
185 |
+
self.use_cache = use_cache
|
186 |
+
|
187 |
+
# fields for vLLM
|
188 |
+
self.sliding_window = attention_window_size
|
189 |
+
|
190 |
+
super().__init__(
|
191 |
+
tokenizer_class=tokenizer_class,
|
192 |
+
pad_token_id=pad_token_id,
|
193 |
+
bos_token_id=bos_token_id,
|
194 |
+
eos_token_id=eos_token_id,
|
195 |
+
tie_word_embeddings=tie_word_embeddings,
|
196 |
+
**kwargs,
|
197 |
+
)
|
198 |
+
|
199 |
+
|
200 |
+
class PlamoAttentionCache(torch.nn.Module):
|
201 |
+
def __init__(self, key: torch.Tensor, value: torch.Tensor) -> None:
|
202 |
+
super().__init__()
|
203 |
+
B, nh, L, c = key.shape
|
204 |
+
assert len(value.shape) == 4
|
205 |
+
assert value.shape[0] == B
|
206 |
+
assert value.shape[2] == L
|
207 |
+
self.register_parameter("key", torch.nn.Parameter(key, requires_grad=False))
|
208 |
+
self.register_parameter("value", torch.nn.Parameter(value, requires_grad=False))
|
209 |
+
|
210 |
+
|
211 |
+
class PlamoMambaCache(torch.nn.Module):
|
212 |
+
def __init__(self, conv_state: torch.Tensor, ssm_state: torch.Tensor) -> None:
|
213 |
+
super().__init__()
|
214 |
+
# conv_state: [B, C, d_conv]
|
215 |
+
# ssm_state: [B, nhead, nchanel_per_head, d_state]
|
216 |
+
assert len(conv_state.shape) == 3
|
217 |
+
assert len(ssm_state.shape) == 4
|
218 |
+
assert conv_state.shape[0] == ssm_state.shape[0]
|
219 |
+
self.register_parameter("conv_state", torch.nn.Parameter(conv_state, requires_grad=False))
|
220 |
+
self.register_parameter("ssm_state", torch.nn.Parameter(ssm_state, requires_grad=False))
|
221 |
+
|
222 |
+
|
223 |
+
PlamoLayerCache = PlamoAttentionCache | PlamoMambaCache
|
224 |
+
|
225 |
+
|
226 |
+
class PlamoCache(torch.nn.Module):
|
227 |
+
"""
|
228 |
+
stores states of the model for fast decoding.
|
229 |
+
`transformers` uses `transformers.Cache` for this purpose, but the interface and variable names are
|
230 |
+
deeply dependent on Transformers architecture (e.g., `key_states`) and it is difficult to use
|
231 |
+
other architectures (e.g., Mamba).
|
232 |
+
This class provides a similar interface to `transformers.Cache`, but is designed to also handle
|
233 |
+
the state of Mamba properly.
|
234 |
+
"""
|
235 |
+
|
236 |
+
def __init__(self, config: PlamoConfig) -> None:
|
237 |
+
super().__init__()
|
238 |
+
self.config = config
|
239 |
+
self.cache = torch.nn.ModuleList([None for _ in range(config.num_hidden_layers)]) # type: ignore
|
240 |
+
|
241 |
+
def append_kv(self, key: torch.Tensor, value: torch.Tensor, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
|
242 |
+
c = self.cache[layer_idx]
|
243 |
+
assert isinstance(c, PlamoAttentionCache)
|
244 |
+
|
245 |
+
def _validate(cache: torch.Tensor, new_tensor: torch.Tensor) -> None:
|
246 |
+
assert len(cache.shape) == 4
|
247 |
+
assert len(new_tensor.shape) == 4
|
248 |
+
assert cache.shape[0] == new_tensor.shape[0]
|
249 |
+
assert cache.shape[1] == new_tensor.shape[1]
|
250 |
+
assert cache.shape[3] == new_tensor.shape[3]
|
251 |
+
|
252 |
+
_validate(c.key, key)
|
253 |
+
_validate(c.value, value)
|
254 |
+
assert key.shape[2] == value.shape[2]
|
255 |
+
return torch.cat([c.key, key], dim=2), torch.cat([c.value, value], dim=2)
|
256 |
+
|
257 |
+
def update_attention(
|
258 |
+
self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int
|
259 |
+
) -> PlamoAttentionCache:
|
260 |
+
if self.cache[layer_idx] is None:
|
261 |
+
self.cache[layer_idx] = PlamoAttentionCache(key_states, value_states)
|
262 |
+
else:
|
263 |
+
full_attn = layer_idx in self.config.full_attention_idx
|
264 |
+
window_size = self.config.attention_window_size
|
265 |
+
c = self.cache[layer_idx]
|
266 |
+
assert isinstance(c, PlamoAttentionCache)
|
267 |
+
k, v = self.append_kv(key_states, value_states, layer_idx)
|
268 |
+
if full_attn:
|
269 |
+
c.key.data = k
|
270 |
+
c.value.data = v
|
271 |
+
else:
|
272 |
+
c.key.data = k[:, :, -window_size:, :]
|
273 |
+
c.value.data = v[:, :, -window_size:, :]
|
274 |
+
return self.cache[layer_idx] # type: ignore
|
275 |
+
|
276 |
+
def update_mamba(self, conv_state: torch.Tensor, ssm_state: torch.Tensor, layer_idx: int) -> PlamoMambaCache:
|
277 |
+
if self.cache[layer_idx] is None:
|
278 |
+
self.cache[layer_idx] = PlamoMambaCache(conv_state, ssm_state)
|
279 |
+
else:
|
280 |
+
c = self.cache[layer_idx]
|
281 |
+
assert isinstance(c, PlamoMambaCache)
|
282 |
+
assert c.conv_state.shape == conv_state.shape
|
283 |
+
assert c.ssm_state.shape == ssm_state.shape
|
284 |
+
c.conv_state.data = conv_state
|
285 |
+
c.ssm_state.data = ssm_state
|
286 |
+
return self.cache[layer_idx] # type: ignore
|
287 |
+
|
288 |
+
def __getitem__(self, layer_idx: int) -> PlamoLayerCache | None:
|
289 |
+
assert layer_idx < len(self.cache)
|
290 |
+
layer_cache = self.cache[layer_idx]
|
291 |
+
return layer_cache # type: ignore
|
292 |
+
|
293 |
+
def __len__(self) -> int:
|
294 |
+
return len(self.cache)
|
295 |
+
|
296 |
+
def get_seq_length(self, layer_idx: Optional[int] = None) -> int:
|
297 |
+
if layer_idx is not None:
|
298 |
+
c = self.cache[layer_idx]
|
299 |
+
assert isinstance(c, PlamoAttentionCache)
|
300 |
+
return c.key.shape[2] # type: ignore
|
301 |
+
|
302 |
+
sequence_length: int | None = None
|
303 |
+
for layer_cache in self.cache:
|
304 |
+
if isinstance(layer_cache, PlamoAttentionCache):
|
305 |
+
sequence_length = (
|
306 |
+
max(layer_cache.key.shape[2], sequence_length)
|
307 |
+
if sequence_length is not None
|
308 |
+
else layer_cache.key.shape[2]
|
309 |
+
)
|
310 |
+
assert sequence_length is not None
|
311 |
+
return sequence_length
|
312 |
+
|
313 |
+
def get_max_length(self) -> int | None:
|
314 |
+
return None
|
315 |
+
|
316 |
+
def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
|
317 |
+
"""Given the sequence length of the new inputs, returns the usable length of the cache."""
|
318 |
+
# Cache without size limit -> all cache is usable
|
319 |
+
# Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
|
320 |
+
# length, we will need to evict part of the cache (and thus not all cache is usable)
|
321 |
+
max_length = self.get_max_length()
|
322 |
+
previous_seq_length = self.get_seq_length(layer_idx)
|
323 |
+
if max_length is not None and previous_seq_length + new_seq_length > max_length:
|
324 |
+
return max_length - new_seq_length
|
325 |
+
return previous_seq_length
|
326 |
+
|
327 |
+
def reorder_cache(self, beam_idx: torch.Tensor) -> None:
|
328 |
+
def _mamba(cache: PlamoMambaCache) -> PlamoMambaCache:
|
329 |
+
return PlamoMambaCache(
|
330 |
+
conv_state=cache.conv_state.index_select(0, beam_idx),
|
331 |
+
ssm_state=cache.ssm_state.index_select(0, beam_idx),
|
332 |
+
)
|
333 |
+
|
334 |
+
def _attention(cache: PlamoAttentionCache) -> PlamoAttentionCache:
|
335 |
+
return PlamoAttentionCache(
|
336 |
+
key=cache.key.index_select(0, beam_idx),
|
337 |
+
value=cache.value.index_select(0, beam_idx),
|
338 |
+
)
|
339 |
+
|
340 |
+
for i in range(len(self.cache)):
|
341 |
+
if self.cache[i] is None:
|
342 |
+
continue
|
343 |
+
layer_cache = self.cache[i]
|
344 |
+
if isinstance(layer_cache, PlamoMambaCache):
|
345 |
+
self.cache[i] = _mamba(layer_cache)
|
346 |
+
else:
|
347 |
+
assert isinstance(layer_cache, PlamoAttentionCache)
|
348 |
+
self.cache[i] = _attention(layer_cache)
|
349 |
+
|
350 |
+
@property
|
351 |
+
def seen_tokens(self) -> int | None:
|
352 |
+
return None
|
353 |
+
|
354 |
+
|
355 |
+
class DecoderInput(NamedTuple):
|
356 |
+
hidden_states: torch.Tensor
|
357 |
+
attention_mask: Optional[torch.Tensor] = None
|
358 |
+
past_states: Optional[PlamoCache] = None
|
359 |
+
output_hidden_states: Optional[bool] = False
|
360 |
+
output_attentions: Optional[bool] = False
|
361 |
+
gradient_checkpointing: bool = False
|
362 |
+
input_ids: Optional[torch.Tensor] = None
|
363 |
+
|
364 |
+
|
365 |
+
class DecoderOutput(NamedTuple):
|
366 |
+
hidden_states: torch.Tensor
|
367 |
+
all_hidden_states: Optional[Tuple[torch.Tensor, ...]]
|
368 |
+
all_self_attns: Optional[Tuple[torch.Tensor, ...]]
|
369 |
+
|
370 |
+
|
371 |
+
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
372 |
+
def _make_causal_mask(
|
373 |
+
input_ids_shape: Tuple[int, int], dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
|
374 |
+
) -> torch.Tensor:
|
375 |
+
"""
|
376 |
+
Make causal mask used for bi-directional self-attention.
|
377 |
+
"""
|
378 |
+
bsz, tgt_len = input_ids_shape
|
379 |
+
mask = torch.full((tgt_len, tgt_len), float("-inf"), device=device)
|
380 |
+
mask_cond = torch.arange(mask.size(-1), device=device)
|
381 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
382 |
+
mask = mask.to(dtype)
|
383 |
+
|
384 |
+
if past_key_values_length > 0:
|
385 |
+
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
|
386 |
+
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
387 |
+
|
388 |
+
|
389 |
+
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
390 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None) -> torch.Tensor:
|
391 |
+
"""
|
392 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
393 |
+
"""
|
394 |
+
bsz, src_len = mask.size()
|
395 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
396 |
+
|
397 |
+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
398 |
+
|
399 |
+
inverted_mask = 1.0 - expanded_mask
|
400 |
+
|
401 |
+
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), float("-inf")) # type: ignore
|
402 |
+
|
403 |
+
|
404 |
+
def _rms_norm(
|
405 |
+
hidden_states: torch.Tensor, weight: Optional[torch.Tensor], eps: float, offset: float = 1.0
|
406 |
+
) -> torch.Tensor:
|
407 |
+
input_dtype = hidden_states.dtype
|
408 |
+
hidden_states = hidden_states.to(torch.float32)
|
409 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
410 |
+
hidden_states = hidden_states * torch.rsqrt(variance + eps)
|
411 |
+
hidden_states = hidden_states.to(input_dtype)
|
412 |
+
if weight is not None:
|
413 |
+
hidden_states = (offset + weight) * hidden_states
|
414 |
+
return hidden_states
|
415 |
+
|
416 |
+
|
417 |
+
class RMSNorm(nn.Module):
|
418 |
+
def __init__(
|
419 |
+
self,
|
420 |
+
hidden_size: int,
|
421 |
+
eps: float = 1e-6,
|
422 |
+
offset: float = 1.0,
|
423 |
+
device: Optional[Union[torch.device, str]] = None,
|
424 |
+
) -> None:
|
425 |
+
super().__init__()
|
426 |
+
self.weight = nn.Parameter(torch.zeros(hidden_size, device=device))
|
427 |
+
self.variance_epsilon = eps
|
428 |
+
self.offset = offset
|
429 |
+
|
430 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
431 |
+
return _rms_norm(hidden_states, self.weight, self.variance_epsilon, offset=self.offset)
|
432 |
+
|
433 |
+
|
434 |
+
def get_initial_dt_bias(num_heads: int) -> torch.Tensor:
|
435 |
+
dt_min = 0.001
|
436 |
+
dt_max = 0.1
|
437 |
+
dt = torch.exp(torch.rand(num_heads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
|
438 |
+
dt = torch.clamp(dt, 1e-4)
|
439 |
+
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
440 |
+
return inv_dt
|
441 |
+
|
442 |
+
|
443 |
+
def get_initial_A(num_heads: int) -> torch.Tensor:
|
444 |
+
A = torch.arange(1, num_heads + 1, dtype=torch.float32)
|
445 |
+
return torch.log(A)
|
446 |
+
|
447 |
+
|
448 |
+
def _bf16_supported_in_triton() -> bool:
|
449 |
+
# newer torch (2.2.0 and later?) supports bfloat16 even when using Voltas
|
450 |
+
# but triton cannot compile bf16 kernels for Volta
|
451 |
+
major, _ = torch.cuda.get_device_capability()
|
452 |
+
return major >= 8
|
453 |
+
|
454 |
+
|
455 |
+
def _get_trition_dtype(dtype: torch.dtype) -> torch.dtype:
|
456 |
+
if dtype != torch.bfloat16:
|
457 |
+
return dtype
|
458 |
+
if _bf16_supported_in_triton():
|
459 |
+
return dtype
|
460 |
+
return torch.float32
|
461 |
+
|
462 |
+
|
463 |
+
def ssd_update_state(
|
464 |
+
ssm_state: torch.Tensor,
|
465 |
+
x: torch.Tensor,
|
466 |
+
dt: torch.Tensor,
|
467 |
+
A: torch.Tensor,
|
468 |
+
B: torch.Tensor,
|
469 |
+
C: torch.Tensor,
|
470 |
+
D: torch.Tensor,
|
471 |
+
z: torch.Tensor,
|
472 |
+
dt_bias: torch.Tensor,
|
473 |
+
dt_softplus: bool,
|
474 |
+
) -> torch.Tensor:
|
475 |
+
assert ssm_state.dtype == torch.float32
|
476 |
+
if dt.is_cuda:
|
477 |
+
dtype = _get_trition_dtype(x.dtype)
|
478 |
+
else:
|
479 |
+
dtype = x.dtype
|
480 |
+
if dt.is_cuda:
|
481 |
+
f = mamba_ssm.ops.triton.selective_state_update.selective_state_update
|
482 |
+
else:
|
483 |
+
f = mamba_ssm.ops.triton.selective_state_update.selective_state_update_ref
|
484 |
+
|
485 |
+
hidden_size_per_head = x.shape[-1]
|
486 |
+
d_state = B.shape[-1]
|
487 |
+
A = A[:, None, None].expand(-1, hidden_size_per_head, d_state).float()
|
488 |
+
dt = dt[..., None].expand(-1, -1, hidden_size_per_head)
|
489 |
+
dt_bias = dt_bias[:, None].expand(-1, hidden_size_per_head)
|
490 |
+
D = D[:, None].expand(-1, hidden_size_per_head)
|
491 |
+
assert ssm_state.dtype == torch.float32
|
492 |
+
out = f(
|
493 |
+
ssm_state,
|
494 |
+
x.to(dtype),
|
495 |
+
dt.to(dtype),
|
496 |
+
A.float(),
|
497 |
+
B.to(dtype),
|
498 |
+
C.to(dtype),
|
499 |
+
D.float(),
|
500 |
+
z.to(dtype),
|
501 |
+
dt_bias.float(),
|
502 |
+
dt_softplus=dt_softplus,
|
503 |
+
)
|
504 |
+
return out[:, None] # type: ignore
|
505 |
+
|
506 |
+
|
507 |
+
def _ssd_chunk_scan_combined_naive(
|
508 |
+
x: torch.Tensor,
|
509 |
+
dt: torch.Tensor,
|
510 |
+
A: torch.Tensor,
|
511 |
+
B: torch.Tensor,
|
512 |
+
C: torch.Tensor,
|
513 |
+
D: torch.Tensor,
|
514 |
+
z: torch.Tensor,
|
515 |
+
dt_bias: torch.Tensor,
|
516 |
+
dt_softplus: bool,
|
517 |
+
seq_idx: torch.Tensor | None,
|
518 |
+
ssm_state: torch.Tensor,
|
519 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
520 |
+
assert ssm_state.dtype == torch.float32
|
521 |
+
length = x.shape[1]
|
522 |
+
ys = []
|
523 |
+
for i in range(length):
|
524 |
+
if i != 0 and seq_idx is not None:
|
525 |
+
ssm_state = torch.where(
|
526 |
+
(seq_idx[:, i - 1] != seq_idx[:, i])[:, None, None, None],
|
527 |
+
torch.zeros_like(ssm_state),
|
528 |
+
ssm_state,
|
529 |
+
)
|
530 |
+
y = ssd_update_state(
|
531 |
+
ssm_state,
|
532 |
+
x[:, i],
|
533 |
+
dt[:, i],
|
534 |
+
A,
|
535 |
+
B[:, i],
|
536 |
+
C[:, i],
|
537 |
+
D,
|
538 |
+
z=z[:, i],
|
539 |
+
dt_bias=dt_bias,
|
540 |
+
dt_softplus=dt_softplus,
|
541 |
+
)
|
542 |
+
ys.append(y)
|
543 |
+
return torch.cat(ys, dim=1), ssm_state
|
544 |
+
|
545 |
+
|
546 |
+
def ssd_chunk_scan_combined(
|
547 |
+
x: torch.Tensor,
|
548 |
+
dt: torch.Tensor,
|
549 |
+
A: torch.Tensor,
|
550 |
+
B: torch.Tensor,
|
551 |
+
C: torch.Tensor,
|
552 |
+
chunk_size: int,
|
553 |
+
D: torch.Tensor,
|
554 |
+
z: torch.Tensor,
|
555 |
+
dt_bias: torch.Tensor,
|
556 |
+
dt_softplus: bool,
|
557 |
+
return_final_states: bool,
|
558 |
+
seq_idx: torch.Tensor | None,
|
559 |
+
ssm_state: torch.Tensor | None,
|
560 |
+
) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
|
561 |
+
if seq_idx is not None:
|
562 |
+
assert seq_idx.dtype == torch.int32
|
563 |
+
assert ssm_state is None
|
564 |
+
assert not return_final_states
|
565 |
+
if ssm_state is not None:
|
566 |
+
assert ssm_state.dtype == torch.float32
|
567 |
+
assert seq_idx is None
|
568 |
+
|
569 |
+
length = x.shape[1]
|
570 |
+
|
571 |
+
"""
|
572 |
+
state will be updates by following:
|
573 |
+
```
|
574 |
+
dt = softplus(dt)
|
575 |
+
dA = exp(dt * A)
|
576 |
+
state_next = state * dA + dB * x
|
577 |
+
```
|
578 |
+
|
579 |
+
To avoid updating state, we set dt to -inf and x to 0
|
580 |
+
because `softplus(-inf) = 0` and `exp(0) = 1`
|
581 |
+
"""
|
582 |
+
if dt.is_cuda:
|
583 |
+
pad = (chunk_size - length % chunk_size) % chunk_size
|
584 |
+
x = torch.nn.functional.pad(x, pad=[0, 0, 0, 0, pad, 0], value=0.0)
|
585 |
+
dt = torch.nn.functional.pad(dt, pad=[0, 0, pad, 0], value=float("-inf"))
|
586 |
+
B = torch.nn.functional.pad(B, pad=[0, 0, 0, 0, pad, 0], value=0.0)
|
587 |
+
C = torch.nn.functional.pad(C, pad=[0, 0, 0, 0, pad, 0], value=0.0)
|
588 |
+
z = torch.nn.functional.pad(z, pad=[0, 0, 0, 0, pad, 0], value=0.0)
|
589 |
+
if seq_idx is not None:
|
590 |
+
seq_idx = torch.nn.functional.pad(seq_idx, pad=[pad, 0], value=0)
|
591 |
+
|
592 |
+
length = x.shape[1]
|
593 |
+
assert length % chunk_size == 0, (length, chunk_size)
|
594 |
+
|
595 |
+
dtype = _get_trition_dtype(x.dtype)
|
596 |
+
out = mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined( # type: ignore
|
597 |
+
x.to(dtype),
|
598 |
+
dt.to(dtype),
|
599 |
+
A.float(),
|
600 |
+
B.to(dtype),
|
601 |
+
C.to(dtype),
|
602 |
+
chunk_size,
|
603 |
+
D=D.float(),
|
604 |
+
z=z.to(dtype),
|
605 |
+
initial_states=ssm_state,
|
606 |
+
dt_bias=dt_bias.float(),
|
607 |
+
dt_softplus=dt_softplus,
|
608 |
+
seq_idx=seq_idx,
|
609 |
+
return_final_states=return_final_states,
|
610 |
+
)
|
611 |
+
if return_final_states:
|
612 |
+
return out[0][:, pad:], out[1]
|
613 |
+
else:
|
614 |
+
assert isinstance(out, torch.Tensor)
|
615 |
+
return out[:, pad:]
|
616 |
+
else:
|
617 |
+
if ssm_state is None:
|
618 |
+
bsize, _, num_heads, channel = x.shape
|
619 |
+
state = B.shape[-1]
|
620 |
+
ssm_state = torch.zeros(bsize, num_heads, channel, state, dtype=torch.float32, device=x.device)
|
621 |
+
tmp = _ssd_chunk_scan_combined_naive(
|
622 |
+
x, dt, A, B, C, D, z=z, dt_bias=dt_bias, dt_softplus=dt_softplus, seq_idx=seq_idx, ssm_state=ssm_state
|
623 |
+
)
|
624 |
+
if return_final_states:
|
625 |
+
return tmp
|
626 |
+
else:
|
627 |
+
return tmp[0]
|
628 |
+
|
629 |
+
|
630 |
+
def _causal_conv1d(
|
631 |
+
conv_state: torch.Tensor | None, weight: torch.Tensor, x: torch.Tensor, seq_idx: torch.Tensor | None
|
632 |
+
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
633 |
+
dtype = x.dtype
|
634 |
+
if conv_state is not None:
|
635 |
+
dtype = conv_state.dtype
|
636 |
+
assert seq_idx is None
|
637 |
+
if seq_idx is not None:
|
638 |
+
assert seq_idx.dtype == torch.int32
|
639 |
+
assert conv_state is None
|
640 |
+
weight = weight.to(dtype)
|
641 |
+
x = x.to(dtype)
|
642 |
+
|
643 |
+
return_final_states = conv_state is not None
|
644 |
+
if weight.is_cuda:
|
645 |
+
if x.stride(1) != 1:
|
646 |
+
# to channel-last format
|
647 |
+
x = x.transpose(-1, -2).contiguous().transpose(-1, -2)
|
648 |
+
if conv_state is not None:
|
649 |
+
if conv_state.stride(1) != 1:
|
650 |
+
# to channel-last format
|
651 |
+
conv_state = conv_state.transpose(-1, -2).contiguous().transpose(-1, -2)
|
652 |
+
tmp = causal_conv1d.causal_conv1d_fn(
|
653 |
+
x=x,
|
654 |
+
weight=weight[:, 0, :],
|
655 |
+
initial_states=conv_state,
|
656 |
+
return_final_states=conv_state is not None,
|
657 |
+
activation="silu",
|
658 |
+
seq_idx=seq_idx,
|
659 |
+
)
|
660 |
+
if conv_state is not None:
|
661 |
+
x, conv_state = tmp
|
662 |
+
else:
|
663 |
+
x = tmp
|
664 |
+
else:
|
665 |
+
if conv_state is None:
|
666 |
+
bsize = x.shape[0]
|
667 |
+
dim = weight.shape[0]
|
668 |
+
d_conv = weight.shape[-1]
|
669 |
+
conv_state = torch.zeros(bsize, dim, d_conv - 1, dtype=x.dtype, device=x.device)
|
670 |
+
length = x.shape[-1]
|
671 |
+
out = torch.zeros_like(x)
|
672 |
+
for i in range(length):
|
673 |
+
if i != 0 and seq_idx is not None:
|
674 |
+
conv_state = torch.where(
|
675 |
+
(seq_idx[:, i - 1] != seq_idx[:, i])[:, None, None],
|
676 |
+
torch.zeros_like(conv_state),
|
677 |
+
conv_state,
|
678 |
+
)
|
679 |
+
out[:, :, i : i + 1], conv_state = _causal_conv1d_update(conv_state, weight, x[:, :, i : i + 1])
|
680 |
+
x = out
|
681 |
+
if return_final_states:
|
682 |
+
return x, conv_state
|
683 |
+
else:
|
684 |
+
return x, None
|
685 |
+
|
686 |
+
|
687 |
+
def _causal_conv1d_update(
|
688 |
+
conv_state: torch.Tensor, weight: torch.Tensor, xBC: torch.Tensor
|
689 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
690 |
+
dtype = conv_state.dtype
|
691 |
+
xBC = xBC.to(dtype)
|
692 |
+
weight = weight.to(dtype)
|
693 |
+
if conv_state.is_cuda:
|
694 |
+
x = causal_conv1d.causal_conv1d_update(
|
695 |
+
x=xBC,
|
696 |
+
conv_state=conv_state,
|
697 |
+
weight=weight[:, 0, :],
|
698 |
+
activation="silu",
|
699 |
+
)
|
700 |
+
return x, conv_state
|
701 |
+
else:
|
702 |
+
x = causal_conv1d.causal_conv1d_update_ref(
|
703 |
+
x=xBC,
|
704 |
+
conv_state=conv_state,
|
705 |
+
weight=weight[:, 0, :],
|
706 |
+
activation="silu",
|
707 |
+
)
|
708 |
+
return x, conv_state
|
709 |
+
|
710 |
+
|
711 |
+
class Mamba(torch.nn.Module):
|
712 |
+
def __init__(self, config: PlamoConfig, layer_idx: int) -> None:
|
713 |
+
super().__init__()
|
714 |
+
self.config = config
|
715 |
+
self.layer_idx = layer_idx
|
716 |
+
self.hidden_size = config.hidden_size
|
717 |
+
self.d_state = config.mamba_d_state
|
718 |
+
self.d_conv = config.mamba_d_conv
|
719 |
+
self.chunk_size = config.mamba_chunk_size
|
720 |
+
self.num_heads = config.mamba_num_heads
|
721 |
+
# TODO add mamba_hidden_size_per_head config (?)
|
722 |
+
self.hidden_size_per_head = config.hidden_size_per_head
|
723 |
+
|
724 |
+
self.intermediate_size = self.num_heads * self.hidden_size_per_head
|
725 |
+
|
726 |
+
self.in_proj = torch.nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
|
727 |
+
self.conv1d = torch.nn.Conv1d(
|
728 |
+
in_channels=self.intermediate_size,
|
729 |
+
out_channels=self.intermediate_size,
|
730 |
+
bias=False, # TODO the original implementation uses bias
|
731 |
+
kernel_size=self.d_conv,
|
732 |
+
groups=self.intermediate_size,
|
733 |
+
padding=0,
|
734 |
+
)
|
735 |
+
self.dt_dim = max(64, self.hidden_size // 16)
|
736 |
+
# Notes:
|
737 |
+
# Mamba2 removes this linear projection for simplicity (Figure 6 in the paper),
|
738 |
+
# but it may degrade the ability of content-length extrapolation.
|
739 |
+
self.bcdt_proj = torch.nn.Linear(
|
740 |
+
self.intermediate_size,
|
741 |
+
self.dt_dim + 2 * self.d_state,
|
742 |
+
bias=False,
|
743 |
+
)
|
744 |
+
self.dt_proj = torch.nn.Linear(self.dt_dim, self.num_heads, bias=False)
|
745 |
+
|
746 |
+
self.dt_bias = torch.nn.Parameter(get_initial_dt_bias(self.num_heads))
|
747 |
+
self.A_log = torch.nn.Parameter(get_initial_A(self.num_heads))
|
748 |
+
self.D = torch.nn.Parameter(torch.ones(self.num_heads))
|
749 |
+
|
750 |
+
# TODO norm weight before gating like Mamba2
|
751 |
+
self.dt_norm_weight = torch.nn.Parameter(torch.ones(self.dt_dim))
|
752 |
+
self.B_norm_weight = torch.nn.Parameter(torch.ones(self.d_state))
|
753 |
+
self.C_norm_weight = torch.nn.Parameter(torch.ones(self.d_state))
|
754 |
+
|
755 |
+
self.out_proj = torch.nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
756 |
+
|
757 |
+
def _no_weight_decay_param_names(self) -> set[str]:
|
758 |
+
return set(["D", "dt_bias", "A_log"])
|
759 |
+
|
760 |
+
def forward(
|
761 |
+
self,
|
762 |
+
hidden_states: torch.Tensor,
|
763 |
+
attention_mask: Optional[torch.Tensor] = None,
|
764 |
+
past_states: Optional[PlamoCache] = None,
|
765 |
+
) -> Tuple[torch.Tensor, Optional[PlamoCache]]:
|
766 |
+
bsize, length, _ = hidden_states.shape
|
767 |
+
is_update = length == 1 and past_states is not None
|
768 |
+
|
769 |
+
bool_mask: torch.Tensor | None = None
|
770 |
+
seq_idx: torch.Tensor | None = None
|
771 |
+
if attention_mask is not None:
|
772 |
+
if len(attention_mask.shape) == 2:
|
773 |
+
attention_mask = attention_mask[None, None].expand(bsize, 1, -1, -1)
|
774 |
+
assert len(attention_mask.shape) == 4
|
775 |
+
|
776 |
+
if past_states is None:
|
777 |
+
# TODO: support seq_idx with cache
|
778 |
+
bool_mask_4d = attention_mask == 0
|
779 |
+
is_first_token = _is_first_token(bool_mask_4d)[:, 0, :]
|
780 |
+
seq_idx = torch.cumsum(is_first_token, dim=-1) - 1
|
781 |
+
seq_idx = seq_idx.to(torch.int32)
|
782 |
+
|
783 |
+
# `generate` function creates attention mask that contains past tokens,
|
784 |
+
# but mamba does not use them
|
785 |
+
attention_mask = attention_mask[:, 0, -length:, -length:]
|
786 |
+
bool_mask = torch.diagonal(attention_mask, dim1=-2, dim2=-1) == 0
|
787 |
+
|
788 |
+
conv_state: torch.Tensor | None
|
789 |
+
ssm_state: torch.Tensor | None
|
790 |
+
if past_states is None:
|
791 |
+
conv_state = None
|
792 |
+
ssm_state = None
|
793 |
+
elif past_states[self.layer_idx] is None:
|
794 |
+
conv_state = torch.zeros(
|
795 |
+
bsize, self.intermediate_size, self.d_conv - 1, dtype=hidden_states.dtype, device=hidden_states.device
|
796 |
+
)
|
797 |
+
ssm_state = torch.zeros(
|
798 |
+
bsize,
|
799 |
+
self.num_heads,
|
800 |
+
self.hidden_size_per_head,
|
801 |
+
self.d_state,
|
802 |
+
dtype=torch.float32,
|
803 |
+
device=hidden_states.device,
|
804 |
+
)
|
805 |
+
else:
|
806 |
+
c = past_states[self.layer_idx]
|
807 |
+
assert isinstance(c, PlamoMambaCache)
|
808 |
+
conv_state = c.conv_state
|
809 |
+
ssm_state = c.ssm_state
|
810 |
+
|
811 |
+
zx = self.in_proj(hidden_states)
|
812 |
+
zx = zx.reshape(bsize, length, self.num_heads, -1)
|
813 |
+
# z: (bsize, length, num_heads, hidden_size_per_head)
|
814 |
+
# x: (bsize, length, num_heads, hidden_size_per_head)
|
815 |
+
z, x = torch.split(zx, [self.hidden_size_per_head, self.hidden_size_per_head], dim=-1)
|
816 |
+
|
817 |
+
# conv
|
818 |
+
x = x.reshape(bsize, length, -1).transpose(1, 2) # (bsize, intermediate_size, length)
|
819 |
+
if bool_mask is not None:
|
820 |
+
x = torch.where(bool_mask[:, None, :], x, 0.0)
|
821 |
+
if is_update:
|
822 |
+
assert conv_state is not None
|
823 |
+
x, conv_state = _causal_conv1d_update(conv_state, self.conv1d.weight, x)
|
824 |
+
else:
|
825 |
+
x, conv_state = _causal_conv1d(conv_state, self.conv1d.weight, x, seq_idx=seq_idx)
|
826 |
+
x = x.to(dtype=hidden_states.dtype)
|
827 |
+
x = x.transpose(1, 2) # (bsize, length, intermediate_size)
|
828 |
+
x = x.reshape(bsize, length, -1)
|
829 |
+
# x: (bsize, length, num_heads, hidden_size_per_head)
|
830 |
+
# B: (bsize, length, 1, d_state)
|
831 |
+
# C: (bsize, length, 1, d_state)
|
832 |
+
# dt: (bsize, length, dt_dim)
|
833 |
+
BCdt = self.bcdt_proj(x)
|
834 |
+
x = x.reshape(bsize, length, self.num_heads, -1)
|
835 |
+
B, C, dt = torch.split(BCdt, [self.d_state, self.d_state, self.dt_dim], dim=-1)
|
836 |
+
B = B[:, :, None, :]
|
837 |
+
C = C[:, :, None, :]
|
838 |
+
|
839 |
+
A = -torch.exp(self.A_log.float()) # (num_heads,)
|
840 |
+
dt = _rms_norm(dt, None, self.config.rms_norm_eps) * self.dt_norm_weight[None, None, :]
|
841 |
+
B = _rms_norm(B, None, self.config.rms_norm_eps) * self.B_norm_weight[None, None, None, :]
|
842 |
+
C = _rms_norm(C, None, self.config.rms_norm_eps) * self.C_norm_weight[None, None, None, :]
|
843 |
+
|
844 |
+
# (bsize, length, num_heads, 1)
|
845 |
+
dt = self.dt_proj(dt)[..., None]
|
846 |
+
|
847 |
+
# TODO it may not be required
|
848 |
+
B = B.expand(-1, -1, self.num_heads, -1)
|
849 |
+
C = C.expand(-1, -1, self.num_heads, -1)
|
850 |
+
|
851 |
+
if bool_mask is not None:
|
852 |
+
"""
|
853 |
+
state will be updates by following:
|
854 |
+
```
|
855 |
+
dt = softplus(dt)
|
856 |
+
dA = exp(dt * A)
|
857 |
+
state_next = state * dA + dB * x
|
858 |
+
```
|
859 |
+
|
860 |
+
To avoid updating state, we set dt to -inf and x to 0
|
861 |
+
because `softplus(-inf) = 0` and `exp(0) = 1`
|
862 |
+
"""
|
863 |
+
dt = torch.where(bool_mask[:, :, None, None], dt, float("-inf"))
|
864 |
+
x = torch.where(bool_mask[:, :, None, None], x, 0.0)
|
865 |
+
|
866 |
+
# ssm
|
867 |
+
if is_update:
|
868 |
+
assert ssm_state is not None
|
869 |
+
out = ssd_update_state(
|
870 |
+
ssm_state,
|
871 |
+
x[:, 0],
|
872 |
+
dt[:, 0].reshape(bsize, -1),
|
873 |
+
A,
|
874 |
+
B[:, 0],
|
875 |
+
C[:, 0],
|
876 |
+
D=self.D,
|
877 |
+
z=z[:, 0],
|
878 |
+
dt_bias=self.dt_bias,
|
879 |
+
dt_softplus=True,
|
880 |
+
)
|
881 |
+
else:
|
882 |
+
tmp = ssd_chunk_scan_combined(
|
883 |
+
x,
|
884 |
+
dt.reshape(bsize, length, -1),
|
885 |
+
A,
|
886 |
+
B,
|
887 |
+
C,
|
888 |
+
self.chunk_size,
|
889 |
+
D=self.D,
|
890 |
+
z=z,
|
891 |
+
dt_bias=self.dt_bias,
|
892 |
+
dt_softplus=True,
|
893 |
+
return_final_states=past_states is not None,
|
894 |
+
seq_idx=seq_idx,
|
895 |
+
ssm_state=ssm_state,
|
896 |
+
)
|
897 |
+
if past_states is not None:
|
898 |
+
out, ssm_state = tmp
|
899 |
+
else:
|
900 |
+
assert isinstance(tmp, torch.Tensor)
|
901 |
+
out = tmp
|
902 |
+
|
903 |
+
y = self.out_proj(out.reshape(bsize, length, -1))
|
904 |
+
|
905 |
+
if past_states is not None:
|
906 |
+
assert ssm_state is not None
|
907 |
+
assert conv_state is not None
|
908 |
+
past_states.update_mamba(conv_state, ssm_state, self.layer_idx)
|
909 |
+
|
910 |
+
return y, past_states
|
911 |
+
|
912 |
+
|
913 |
+
def swa_mask(q_len: int, kv_len: int, device: torch.device, window_size: int) -> torch.Tensor:
|
914 |
+
max_len = max(q_len, kv_len)
|
915 |
+
mask = (
|
916 |
+
torch.ones(max_len, max_len, dtype=torch.bool, device=device)
|
917 |
+
.triu(diagonal=-window_size)
|
918 |
+
.tril(diagonal=window_size)
|
919 |
+
)
|
920 |
+
return mask[-q_len:, -kv_len:]
|
921 |
+
|
922 |
+
|
923 |
+
class Attention(torch.nn.Module):
|
924 |
+
def __init__(self, config: PlamoConfig, layer_idx: int) -> None:
|
925 |
+
super().__init__()
|
926 |
+
self.config = config
|
927 |
+
self.layer_idx = layer_idx
|
928 |
+
self.hidden_size = config.hidden_size
|
929 |
+
head_dim = config.hidden_size_per_head
|
930 |
+
self.max_position_embeddings = config.max_position_embeddings
|
931 |
+
|
932 |
+
self.q_num_heads = config.num_attention_heads
|
933 |
+
self.qk_dim = self.v_dim = head_dim
|
934 |
+
self.k_num_heads = self.v_num_heads = config.num_key_value_heads
|
935 |
+
assert self.q_num_heads % self.k_num_heads == 0
|
936 |
+
self.n_group = self.q_num_heads // self.k_num_heads
|
937 |
+
|
938 |
+
self.q_proj_dim = self.q_num_heads * self.qk_dim
|
939 |
+
self.k_proj_dim = self.k_num_heads * self.qk_dim
|
940 |
+
self.v_proj_dim = self.k_num_heads * self.v_dim
|
941 |
+
self.qkv_proj = nn.Linear(self.hidden_size, self.q_proj_dim + self.k_proj_dim + self.v_proj_dim, bias=False)
|
942 |
+
self.o_proj = nn.Linear(self.q_num_heads * self.v_dim, self.hidden_size, bias=False)
|
943 |
+
|
944 |
+
self.q_weight = torch.nn.Parameter(torch.ones((self.q_num_heads, self.qk_dim)))
|
945 |
+
self.k_weight = torch.nn.Parameter(torch.ones((self.k_num_heads, self.qk_dim)))
|
946 |
+
|
947 |
+
self.rotary_emb = RotaryEmbedding(self.qk_dim, max_position_embeddings=self.config.attention_window_size)
|
948 |
+
|
949 |
+
def forward(
|
950 |
+
self,
|
951 |
+
hidden_states: torch.Tensor,
|
952 |
+
attention_mask: Optional[torch.Tensor] = None,
|
953 |
+
past_states: Optional[PlamoCache] = None,
|
954 |
+
output_attentions: bool = False,
|
955 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[PlamoCache]]:
|
956 |
+
bsz, q_len, _ = hidden_states.size()
|
957 |
+
|
958 |
+
qkv = self.qkv_proj(hidden_states)
|
959 |
+
query_states, key_states, value_states = torch.split(
|
960 |
+
qkv, [self.q_proj_dim, self.k_proj_dim, self.v_proj_dim], dim=-1
|
961 |
+
)
|
962 |
+
query_states = query_states.view(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(1, 2)
|
963 |
+
key_states = key_states.view(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(1, 2)
|
964 |
+
value_states = value_states.view(bsz, q_len, self.v_num_heads, self.v_dim).transpose(1, 2)
|
965 |
+
|
966 |
+
attn_dtype = query_states.dtype
|
967 |
+
|
968 |
+
query_states = _rms_norm(query_states, None, 1e-6) * self.q_weight[None, :, None]
|
969 |
+
key_states = _rms_norm(key_states, None, 1e-6) * self.k_weight[None, :, None]
|
970 |
+
|
971 |
+
if past_states is not None and past_states[self.layer_idx] is None:
|
972 |
+
bsz, nhead_k, _, c_k = key_states.shape
|
973 |
+
_, nhead_v, _, c_v = value_states.shape
|
974 |
+
past_states.update_attention(
|
975 |
+
torch.zeros((bsz, nhead_k, 0, c_k), dtype=key_states.dtype, device=key_states.device),
|
976 |
+
torch.zeros((bsz, nhead_v, 0, c_v), dtype=value_states.dtype, device=value_states.device),
|
977 |
+
self.layer_idx,
|
978 |
+
)
|
979 |
+
|
980 |
+
if past_states is not None:
|
981 |
+
# reuse k, v, self_attention
|
982 |
+
key_states_new = key_states
|
983 |
+
value_states_new = value_states
|
984 |
+
key_states, value_states = past_states.append_kv(key_states, value_states, self.layer_idx) # type: ignore
|
985 |
+
past_states.update_attention(key_states_new, value_states_new, self.layer_idx)
|
986 |
+
|
987 |
+
kv_seq_len = key_states.shape[-2]
|
988 |
+
device = hidden_states.device
|
989 |
+
position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=device)[None]
|
990 |
+
q_position_ids = position_ids[:, -query_states.shape[2] :]
|
991 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
992 |
+
query_states = _rotary_pos_emb(query_states, cos, sin, q_position_ids)
|
993 |
+
key_states = _rotary_pos_emb(key_states, cos, sin, position_ids)
|
994 |
+
# [bsz, nh, t, hd]
|
995 |
+
|
996 |
+
def _expand_kv(t: torch.Tensor, repeat: int, target: int) -> torch.Tensor:
|
997 |
+
t = torch.repeat_interleave(t, repeat, dim=1)
|
998 |
+
return t[:, :target]
|
999 |
+
|
1000 |
+
# expand shared kv
|
1001 |
+
assert self.k_num_heads == self.v_num_heads
|
1002 |
+
key_states = _expand_kv(key_states, self.n_group, self.q_num_heads)
|
1003 |
+
value_states = _expand_kv(value_states, self.n_group, self.q_num_heads)
|
1004 |
+
|
1005 |
+
full_attn = self.layer_idx in self.config.full_attention_idx
|
1006 |
+
|
1007 |
+
query_states = query_states.to(attn_dtype)
|
1008 |
+
key_states = key_states.to(attn_dtype)
|
1009 |
+
value_states = value_states.to(attn_dtype)
|
1010 |
+
if attention_mask is not None and attention_mask.dtype != torch.bool:
|
1011 |
+
attention_mask = attention_mask.to(attn_dtype)
|
1012 |
+
if attention_mask is None:
|
1013 |
+
if not full_attn:
|
1014 |
+
assert key_states.shape[2] <= self.config.attention_window_size + 1
|
1015 |
+
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=True)
|
1016 |
+
else:
|
1017 |
+
if attention_mask.dtype == torch.bool:
|
1018 |
+
attention_mask = torch.where(attention_mask, torch.tensor(0.0, dtype=torch.float), float("-inf"))
|
1019 |
+
if len(attention_mask.shape) == 2:
|
1020 |
+
attention_mask = attention_mask[None, None]
|
1021 |
+
assert len(attention_mask.shape) == 4
|
1022 |
+
|
1023 |
+
if not full_attn:
|
1024 |
+
m_swa = swa_mask(
|
1025 |
+
query_states.shape[2], key_states.shape[2], query_states.device, self.config.attention_window_size
|
1026 |
+
)
|
1027 |
+
# `generate` function creates attention mask that does not consider sliding window
|
1028 |
+
m_swa = m_swa[None, None]
|
1029 |
+
attention_mask = attention_mask[:, :, -query_states.shape[2] :, -key_states.shape[2] :]
|
1030 |
+
attention_mask = torch.where(m_swa, attention_mask, float("-inf"))
|
1031 |
+
|
1032 |
+
# like AttentionMaskConverter._unmask_unattended in huggingface.transfoermers,
|
1033 |
+
# we need to attend to all tokens in masked rows for `scaled_dot_product_attention`
|
1034 |
+
bool_mask = torch.logical_not(torch.isneginf(attention_mask))
|
1035 |
+
valid_tokens = torch.sum(bool_mask, dim=-1).bool() # (..., q_len)
|
1036 |
+
attention_mask = torch.where(valid_tokens[..., None], attention_mask, float(0.0))
|
1037 |
+
attn_output = F.scaled_dot_product_attention(
|
1038 |
+
query_states, key_states, value_states, attn_mask=attention_mask
|
1039 |
+
)
|
1040 |
+
|
1041 |
+
attn_output = attn_output.transpose(1, 2)
|
1042 |
+
|
1043 |
+
attn_output = attn_output.reshape(bsz, q_len, self.q_num_heads * self.v_dim)
|
1044 |
+
attn_output = self.o_proj(attn_output)
|
1045 |
+
|
1046 |
+
if not output_attentions:
|
1047 |
+
attn_weights = None
|
1048 |
+
|
1049 |
+
return attn_output, attn_weights, past_states
|
1050 |
+
|
1051 |
+
|
1052 |
+
class MLP(nn.Module):
|
1053 |
+
def __init__(self, config: PlamoConfig) -> None:
|
1054 |
+
super().__init__()
|
1055 |
+
self.config = config
|
1056 |
+
self.hidden_size = config.hidden_size
|
1057 |
+
self.intermediate_size = config.intermediate_size
|
1058 |
+
self.gate_up_proj = torch.nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
|
1059 |
+
self.down_proj = torch.nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
1060 |
+
|
1061 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
1062 |
+
h = self.gate_up_proj(x)
|
1063 |
+
h = _swiglu(h)
|
1064 |
+
return self.down_proj(h) # type: ignore
|
1065 |
+
|
1066 |
+
|
1067 |
+
class PlamoDecoderLayer(torch.nn.Module):
|
1068 |
+
def __init__(self, config: PlamoConfig, is_mamba: bool, layer_idx: int) -> None:
|
1069 |
+
super().__init__()
|
1070 |
+
self.config = config
|
1071 |
+
self.hidden_size = config.hidden_size
|
1072 |
+
self.is_mamba = is_mamba
|
1073 |
+
self.mixer: torch.nn.Module
|
1074 |
+
if is_mamba:
|
1075 |
+
self.mixer = Mamba(config, layer_idx)
|
1076 |
+
else:
|
1077 |
+
self.mixer = Attention(config, layer_idx)
|
1078 |
+
self.mlp = MLP(config)
|
1079 |
+
"""
|
1080 |
+
Notes: The model performance was degraded when setting all offsets to 1.
|
1081 |
+
"""
|
1082 |
+
self.pre_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0)
|
1083 |
+
self.post_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / 5)
|
1084 |
+
self.pre_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0)
|
1085 |
+
self.post_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / (5**1.5))
|
1086 |
+
|
1087 |
+
def forward(
|
1088 |
+
self,
|
1089 |
+
hidden_states: torch.Tensor,
|
1090 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1091 |
+
past_state: Optional[PlamoCache] = None,
|
1092 |
+
output_attentions: Optional[bool] = False,
|
1093 |
+
) -> Tuple[Any, ...]:
|
1094 |
+
# from LlamaDecoder
|
1095 |
+
residual = hidden_states
|
1096 |
+
hidden_states = self.pre_mixer_norm(hidden_states)
|
1097 |
+
|
1098 |
+
# Self Attention
|
1099 |
+
if self.is_mamba:
|
1100 |
+
hidden_states_sa, present_key_value = self.mixer(
|
1101 |
+
hidden_states=hidden_states,
|
1102 |
+
attention_mask=attention_mask,
|
1103 |
+
past_states=past_state,
|
1104 |
+
)
|
1105 |
+
self_attn_weights = None
|
1106 |
+
else:
|
1107 |
+
hidden_states_sa, self_attn_weights, present_key_value = self.mixer(
|
1108 |
+
hidden_states=hidden_states,
|
1109 |
+
attention_mask=attention_mask,
|
1110 |
+
past_states=past_state,
|
1111 |
+
output_attentions=output_attentions,
|
1112 |
+
)
|
1113 |
+
|
1114 |
+
hidden_states_sa = self.post_mixer_norm(hidden_states_sa)
|
1115 |
+
hidden_states = residual + hidden_states_sa
|
1116 |
+
|
1117 |
+
residual = hidden_states
|
1118 |
+
hidden_states = self.pre_mlp_norm(hidden_states)
|
1119 |
+
|
1120 |
+
# Fully Connected
|
1121 |
+
hidden_states_mlp = self.mlp(hidden_states)
|
1122 |
+
|
1123 |
+
# Residual
|
1124 |
+
hidden_states_mlp = self.post_mlp_norm(hidden_states_mlp)
|
1125 |
+
hidden_states = residual + hidden_states_mlp
|
1126 |
+
|
1127 |
+
outputs: Any = (hidden_states,)
|
1128 |
+
|
1129 |
+
if output_attentions:
|
1130 |
+
outputs += (self_attn_weights,)
|
1131 |
+
|
1132 |
+
return outputs # type: ignore
|
1133 |
+
|
1134 |
+
|
1135 |
+
def is_mamba(config: PlamoConfig, i: int) -> bool:
|
1136 |
+
if not config.mamba_enabled:
|
1137 |
+
return False
|
1138 |
+
assert config.mamba_step > 1
|
1139 |
+
assert i < config.num_hidden_layers
|
1140 |
+
|
1141 |
+
if config.num_hidden_layers <= (config.mamba_step // 2):
|
1142 |
+
# use attention in last layer
|
1143 |
+
return i != config.num_hidden_layers - 1
|
1144 |
+
return (i % config.mamba_step) != (config.mamba_step // 2)
|
1145 |
+
|
1146 |
+
|
1147 |
+
class PlamoDecoder(torch.nn.Module):
|
1148 |
+
def __init__(self, config: PlamoConfig) -> None:
|
1149 |
+
super().__init__()
|
1150 |
+
|
1151 |
+
self.layers = torch.nn.ModuleList(
|
1152 |
+
[
|
1153 |
+
PlamoDecoderLayer(config, is_mamba=is_mamba(config, i), layer_idx=i)
|
1154 |
+
for i in range(config.num_hidden_layers)
|
1155 |
+
]
|
1156 |
+
)
|
1157 |
+
|
1158 |
+
def forward(self, x: DecoderInput) -> DecoderOutput:
|
1159 |
+
all_hidden_states: Optional[Tuple[torch.Tensor, ...]] = () if x.output_hidden_states else None
|
1160 |
+
all_self_attns: Optional[Tuple[torch.Tensor, ...]] = () if x.output_attentions else None
|
1161 |
+
hidden_states = x.hidden_states
|
1162 |
+
|
1163 |
+
for decoder_layer in self.layers:
|
1164 |
+
if x.output_hidden_states:
|
1165 |
+
assert all_hidden_states is not None
|
1166 |
+
all_hidden_states += (hidden_states,)
|
1167 |
+
|
1168 |
+
if self.training and x.gradient_checkpointing:
|
1169 |
+
|
1170 |
+
def create_custom_forward(module): # type: ignore
|
1171 |
+
def custom_forward(*inputs): # type: ignore
|
1172 |
+
# None for past_key_value
|
1173 |
+
return module(*inputs, x.output_attentions, None)
|
1174 |
+
|
1175 |
+
return custom_forward
|
1176 |
+
|
1177 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
1178 |
+
create_custom_forward(decoder_layer), # type: ignore
|
1179 |
+
hidden_states,
|
1180 |
+
x.attention_mask,
|
1181 |
+
None,
|
1182 |
+
)
|
1183 |
+
else:
|
1184 |
+
layer_outputs = decoder_layer(
|
1185 |
+
hidden_states,
|
1186 |
+
attention_mask=x.attention_mask,
|
1187 |
+
past_state=x.past_states,
|
1188 |
+
output_attentions=x.output_attentions,
|
1189 |
+
)
|
1190 |
+
|
1191 |
+
hidden_states = layer_outputs[0]
|
1192 |
+
|
1193 |
+
if x.output_attentions:
|
1194 |
+
assert layer_outputs[1] is not None
|
1195 |
+
assert all_self_attns is not None
|
1196 |
+
all_self_attns += (layer_outputs[1],)
|
1197 |
+
return DecoderOutput(hidden_states, all_hidden_states, all_self_attns)
|
1198 |
+
|
1199 |
+
|
1200 |
+
class PlamoPreTrainedModel(PreTrainedModel): # type: ignore
|
1201 |
+
config_class = PlamoConfig
|
1202 |
+
_no_split_modules: List[str]
|
1203 |
+
base_model_prefix = "model"
|
1204 |
+
supports_gradient_checkpointing = True
|
1205 |
+
_no_split_modules = ["PlamoDecoderLayer"]
|
1206 |
+
_skip_keys_device_placement = "past_key_values"
|
1207 |
+
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
|
1208 |
+
|
1209 |
+
def _init_weights(self, module: torch.nn.Module) -> None:
|
1210 |
+
std = 0.02
|
1211 |
+
if isinstance(module, nn.Linear):
|
1212 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
1213 |
+
if module.bias is not None:
|
1214 |
+
module.bias.data.zero_()
|
1215 |
+
elif isinstance(module, nn.Embedding):
|
1216 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
1217 |
+
if module.padding_idx is not None:
|
1218 |
+
module.weight.data[module.padding_idx].zero_()
|
1219 |
+
|
1220 |
+
def _set_gradient_checkpointing(self, module: torch.nn.Module, value: bool = False) -> None:
|
1221 |
+
module.gradient_checkpointing = value # type: ignore
|
1222 |
+
|
1223 |
+
|
1224 |
+
class PlamoModel(PlamoPreTrainedModel):
|
1225 |
+
def __init__(self, config: PlamoConfig):
|
1226 |
+
super().__init__(config)
|
1227 |
+
assert config.eval_attention_n_bit is None
|
1228 |
+
assert config.eval_mlp_n_bit is None
|
1229 |
+
|
1230 |
+
self.padding_idx = config.pad_token_id
|
1231 |
+
self.vocab_size = config.vocab_size
|
1232 |
+
|
1233 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
1234 |
+
if config.image_feature_size is not None:
|
1235 |
+
if config.image_proj_type == "mlp":
|
1236 |
+
self.image_proj = MLPImageProjector(config) # type: ignore
|
1237 |
+
elif config.image_proj_type == "linear":
|
1238 |
+
self.image_proj = nn.Linear(config.image_feature_size, config.hidden_size, bias=False) # type: ignore
|
1239 |
+
else:
|
1240 |
+
raise ValueError(f"Unknown image_proj_type: {config.image_proj_type}")
|
1241 |
+
self.layers = PlamoDecoder(config) # type: ignore
|
1242 |
+
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
1243 |
+
|
1244 |
+
self.gradient_checkpointing = False
|
1245 |
+
# Initialize weights and apply final processing
|
1246 |
+
self.post_init()
|
1247 |
+
|
1248 |
+
def get_input_embeddings(self) -> torch.nn.Embedding:
|
1249 |
+
return self.embed_tokens
|
1250 |
+
|
1251 |
+
def set_input_embeddings(self, value: torch.nn.Embedding) -> None:
|
1252 |
+
self.embed_tokens = value
|
1253 |
+
|
1254 |
+
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
1255 |
+
def _prepare_decoder_attention_mask(
|
1256 |
+
self,
|
1257 |
+
attention_mask: torch.Tensor,
|
1258 |
+
input_shape: Tuple[int, int],
|
1259 |
+
inputs_embeds: Optional[torch.Tensor],
|
1260 |
+
past_key_values_length: int,
|
1261 |
+
) -> Optional[torch.Tensor]:
|
1262 |
+
# create causal mask
|
1263 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
1264 |
+
combined_attention_mask: Optional[torch.Tensor] = None
|
1265 |
+
if input_shape[-1] > 1:
|
1266 |
+
assert inputs_embeds is not None
|
1267 |
+
combined_attention_mask = _make_causal_mask(
|
1268 |
+
input_shape,
|
1269 |
+
inputs_embeds.dtype,
|
1270 |
+
device=inputs_embeds.device,
|
1271 |
+
past_key_values_length=past_key_values_length,
|
1272 |
+
)
|
1273 |
+
input_shape = (input_shape[0], combined_attention_mask.shape[2])
|
1274 |
+
|
1275 |
+
if attention_mask is not None:
|
1276 |
+
if attention_mask.dim() == 4:
|
1277 |
+
# Custom 4D attention mask
|
1278 |
+
expanded_attn_mask = attention_mask
|
1279 |
+
else:
|
1280 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
1281 |
+
assert inputs_embeds is not None
|
1282 |
+
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
|
1283 |
+
inputs_embeds.device
|
1284 |
+
)
|
1285 |
+
combined_attention_mask = (
|
1286 |
+
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
1287 |
+
)
|
1288 |
+
|
1289 |
+
return combined_attention_mask
|
1290 |
+
|
1291 |
+
def forward(
|
1292 |
+
self,
|
1293 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1294 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1295 |
+
position_ids: Optional[torch.Tensor] = None,
|
1296 |
+
past_key_values: Optional[PlamoCache] = None,
|
1297 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1298 |
+
image_features: Optional[torch.Tensor] = None,
|
1299 |
+
use_cache: Optional[bool] = None,
|
1300 |
+
output_attentions: Optional[bool] = None,
|
1301 |
+
output_hidden_states: Optional[bool] = None,
|
1302 |
+
return_dict: Optional[bool] = None,
|
1303 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
1304 |
+
assert input_ids is not None
|
1305 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1306 |
+
output_hidden_states = (
|
1307 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1308 |
+
)
|
1309 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
1310 |
+
|
1311 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1312 |
+
|
1313 |
+
# retrieve input_ids and inputs_embeds
|
1314 |
+
if input_ids is not None and inputs_embeds is not None:
|
1315 |
+
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
1316 |
+
elif input_ids is not None:
|
1317 |
+
batch_size, seq_length = input_ids.shape
|
1318 |
+
else:
|
1319 |
+
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
1320 |
+
|
1321 |
+
seq_length_with_past = seq_length
|
1322 |
+
past_key_values_length = 0
|
1323 |
+
|
1324 |
+
if past_key_values is not None:
|
1325 |
+
past_key_values_length = past_key_values.get_seq_length()
|
1326 |
+
seq_length_with_past = seq_length_with_past + past_key_values_length
|
1327 |
+
|
1328 |
+
if inputs_embeds is None:
|
1329 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
1330 |
+
|
1331 |
+
if image_features is not None:
|
1332 |
+
assert self.config.image_token_id is not None
|
1333 |
+
image_embeds = self.image_proj(image_features)
|
1334 |
+
assert image_embeds.shape == inputs_embeds.shape, (image_embeds.shape, inputs_embeds.shape)
|
1335 |
+
mask = input_ids == self.config.image_token_id
|
1336 |
+
inputs_embeds[mask] = image_embeds[mask]
|
1337 |
+
|
1338 |
+
# embed positions
|
1339 |
+
require_attn_mask = False
|
1340 |
+
if not self.training or past_key_values is not None:
|
1341 |
+
require_attn_mask = True
|
1342 |
+
if seq_length_with_past >= self.config.attention_window_size:
|
1343 |
+
require_attn_mask = True
|
1344 |
+
if require_attn_mask and attention_mask is None:
|
1345 |
+
attention_mask = torch.ones(
|
1346 |
+
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
1347 |
+
)
|
1348 |
+
if attention_mask is not None:
|
1349 |
+
attention_mask = self._prepare_decoder_attention_mask(
|
1350 |
+
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
1351 |
+
)
|
1352 |
+
|
1353 |
+
hidden_states = inputs_embeds
|
1354 |
+
|
1355 |
+
if self.gradient_checkpointing and self.training:
|
1356 |
+
if use_cache:
|
1357 |
+
use_cache = False
|
1358 |
+
|
1359 |
+
if use_cache and past_key_values is None:
|
1360 |
+
past_key_values = PlamoCache(self.config)
|
1361 |
+
|
1362 |
+
# decoder layers
|
1363 |
+
out = self.layers(
|
1364 |
+
DecoderInput(
|
1365 |
+
hidden_states,
|
1366 |
+
attention_mask,
|
1367 |
+
past_key_values,
|
1368 |
+
output_hidden_states,
|
1369 |
+
output_attentions,
|
1370 |
+
self.gradient_checkpointing,
|
1371 |
+
)
|
1372 |
+
)
|
1373 |
+
assert isinstance(out, DecoderOutput)
|
1374 |
+
hidden_states = out.hidden_states
|
1375 |
+
all_hidden_states = out.all_hidden_states
|
1376 |
+
all_self_attns = out.all_self_attns
|
1377 |
+
|
1378 |
+
hidden_states = self.norm(hidden_states)
|
1379 |
+
|
1380 |
+
# add hidden states from the last decoder layer
|
1381 |
+
if output_hidden_states:
|
1382 |
+
assert all_hidden_states is not None
|
1383 |
+
all_hidden_states += (hidden_states,)
|
1384 |
+
|
1385 |
+
if not return_dict:
|
1386 |
+
return tuple(
|
1387 |
+
v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None
|
1388 |
+
)
|
1389 |
+
return BaseModelOutputWithPast(
|
1390 |
+
last_hidden_state=hidden_states,
|
1391 |
+
past_key_values=past_key_values,
|
1392 |
+
hidden_states=all_hidden_states,
|
1393 |
+
attentions=all_self_attns,
|
1394 |
+
)
|
1395 |
+
|
1396 |
+
|
1397 |
+
class PlamoForCausalLM(PlamoPreTrainedModel):
|
1398 |
+
_tied_weights_keys = ["lm_head.weight"]
|
1399 |
+
|
1400 |
+
# Without this, the model cannot be loaded into a meta device.
|
1401 |
+
# Relevant code:
|
1402 |
+
# https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/modeling_utils.py#L4376-L4381
|
1403 |
+
# https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/modeling_utils.py#L356
|
1404 |
+
# https://github.com/pytorch/pytorch/blob/v2.4.1/torch/nn/modules/module.py#L2068
|
1405 |
+
_supports_param_buffer_assignment = False
|
1406 |
+
|
1407 |
+
def __init__(self, config: PlamoConfig) -> None:
|
1408 |
+
super().__init__(config)
|
1409 |
+
self.model = PlamoModel(config)
|
1410 |
+
|
1411 |
+
self.vocab_size = config.vocab_size
|
1412 |
+
vocab_size = ((self.vocab_size + 15) // 16) * 16
|
1413 |
+
self.lm_head: torch.nn.Module = nn.Linear(config.hidden_size, vocab_size, bias=False)
|
1414 |
+
|
1415 |
+
# Initialize weights and apply final processing
|
1416 |
+
self.post_init()
|
1417 |
+
|
1418 |
+
def get_input_embeddings(self) -> torch.nn.Embedding:
|
1419 |
+
return self.model.embed_tokens
|
1420 |
+
|
1421 |
+
def set_input_embeddings(self, value: torch.nn.Embedding) -> None:
|
1422 |
+
self.model.embed_tokens = value
|
1423 |
+
|
1424 |
+
def get_output_embeddings(self) -> torch.nn.Module:
|
1425 |
+
return self.lm_head
|
1426 |
+
|
1427 |
+
def set_output_embeddings(self, new_embeddings: torch.nn.Module) -> None:
|
1428 |
+
self.lm_head = new_embeddings
|
1429 |
+
|
1430 |
+
def set_decoder(self, decoder: PlamoModel) -> None:
|
1431 |
+
self.model = decoder
|
1432 |
+
|
1433 |
+
def get_decoder(self) -> PlamoModel:
|
1434 |
+
return self.model
|
1435 |
+
|
1436 |
+
def forward( # type: ignore
|
1437 |
+
self,
|
1438 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1439 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1440 |
+
position_ids: Optional[torch.Tensor] = None,
|
1441 |
+
past_key_values: Optional[PlamoCache] = None,
|
1442 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1443 |
+
image_features: Optional[torch.Tensor] = None,
|
1444 |
+
labels: Optional[torch.LongTensor] = None,
|
1445 |
+
use_cache: Optional[bool] = None,
|
1446 |
+
output_attentions: Optional[bool] = None,
|
1447 |
+
output_hidden_states: Optional[bool] = None,
|
1448 |
+
return_dict: Optional[bool] = None,
|
1449 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
1450 |
+
r"""
|
1451 |
+
Args:
|
1452 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1453 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
1454 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
1455 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
1456 |
+
|
1457 |
+
Returns:
|
1458 |
+
|
1459 |
+
Example:
|
1460 |
+
|
1461 |
+
```python
|
1462 |
+
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
1463 |
+
|
1464 |
+
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
1465 |
+
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
1466 |
+
|
1467 |
+
>>> prompt = "Hey, are you consciours? Can you talk to me?"
|
1468 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
1469 |
+
|
1470 |
+
>>> # Generate
|
1471 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
1472 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
1473 |
+
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
|
1474 |
+
```"""
|
1475 |
+
assert input_ids is not None
|
1476 |
+
|
1477 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1478 |
+
output_hidden_states = (
|
1479 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1480 |
+
)
|
1481 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1482 |
+
|
1483 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
1484 |
+
outputs = self.model(
|
1485 |
+
input_ids=input_ids,
|
1486 |
+
attention_mask=attention_mask,
|
1487 |
+
position_ids=position_ids,
|
1488 |
+
past_key_values=past_key_values,
|
1489 |
+
inputs_embeds=inputs_embeds,
|
1490 |
+
image_features=image_features,
|
1491 |
+
use_cache=use_cache,
|
1492 |
+
output_attentions=output_attentions,
|
1493 |
+
output_hidden_states=output_hidden_states,
|
1494 |
+
return_dict=return_dict,
|
1495 |
+
)
|
1496 |
+
|
1497 |
+
hidden_states = outputs[0]
|
1498 |
+
logits = self.lm_head(hidden_states)
|
1499 |
+
logits = logits[..., : self.vocab_size]
|
1500 |
+
|
1501 |
+
loss = None
|
1502 |
+
if labels is not None:
|
1503 |
+
# Shift so that tokens < n predict n
|
1504 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
1505 |
+
shift_labels = labels[..., 1:].contiguous()
|
1506 |
+
# Flatten the tokens
|
1507 |
+
loss_fct = nn.CrossEntropyLoss()
|
1508 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
1509 |
+
shift_labels = shift_labels.view(-1)
|
1510 |
+
# Enable model parallelism
|
1511 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
1512 |
+
loss = loss_fct(shift_logits, shift_labels)
|
1513 |
+
|
1514 |
+
if not return_dict:
|
1515 |
+
output = (logits,) + outputs[1:]
|
1516 |
+
return (loss,) + output if loss is not None else output
|
1517 |
+
|
1518 |
+
return CausalLMOutputWithPast(
|
1519 |
+
loss=loss,
|
1520 |
+
logits=logits,
|
1521 |
+
past_key_values=outputs.past_key_values,
|
1522 |
+
hidden_states=outputs.hidden_states,
|
1523 |
+
attentions=outputs.attentions,
|
1524 |
+
)
|
1525 |
+
|
1526 |
+
def prepare_inputs_for_generation(
|
1527 |
+
self,
|
1528 |
+
input_ids: torch.Tensor,
|
1529 |
+
past_key_values: Optional[PlamoCache] = None,
|
1530 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1531 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1532 |
+
image_features: Optional[torch.Tensor] = None,
|
1533 |
+
**kwargs: Any,
|
1534 |
+
) -> Dict[str, Any]:
|
1535 |
+
if past_key_values:
|
1536 |
+
input_ids = input_ids[:, -1:]
|
1537 |
+
if image_features is not None:
|
1538 |
+
image_features = image_features[:, -1:, :]
|
1539 |
+
|
1540 |
+
position_ids = kwargs.get("position_ids", None)
|
1541 |
+
if attention_mask is not None and position_ids is None:
|
1542 |
+
# create position_ids on the fly for batch generation
|
1543 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
1544 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
1545 |
+
if past_key_values:
|
1546 |
+
position_ids = position_ids[:, -1].unsqueeze(-1)
|
1547 |
+
|
1548 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
1549 |
+
if inputs_embeds is not None and past_key_values is None:
|
1550 |
+
model_inputs: Dict[str, Any] = {"inputs_embeds": inputs_embeds}
|
1551 |
+
else:
|
1552 |
+
model_inputs = {"input_ids": input_ids}
|
1553 |
+
|
1554 |
+
model_inputs.update(
|
1555 |
+
{
|
1556 |
+
"position_ids": position_ids,
|
1557 |
+
"past_key_values": past_key_values,
|
1558 |
+
"use_cache": kwargs.get("use_cache"),
|
1559 |
+
"attention_mask": attention_mask,
|
1560 |
+
"image_features": image_features,
|
1561 |
+
}
|
1562 |
+
)
|
1563 |
+
return model_inputs
|
1564 |
+
|
1565 |
+
@staticmethod
|
1566 |
+
def _reorder_cache(past_key_values: PlamoCache, beam_idx: torch.Tensor) -> PlamoCache:
|
1567 |
+
past_key_values.reorder_cache(beam_idx)
|
1568 |
+
return past_key_values
|
1569 |
+
|
1570 |
+
|
1571 |
+
class MLPImageProjector(nn.Module):
|
1572 |
+
def __init__(self, config: PlamoConfig) -> None:
|
1573 |
+
super().__init__()
|
1574 |
+
self.config = config
|
1575 |
+
|
1576 |
+
assert config.image_feature_size is not None # for typing
|
1577 |
+
|
1578 |
+
# nn.LayerNorm is not supported by PFVM, so use RMSNorm + Bias instead to approximate this.
|
1579 |
+
self.norm0 = RMSNorm(config.image_feature_size, eps=config.rms_norm_eps)
|
1580 |
+
self.bias0 = Bias(config.image_feature_size)
|
1581 |
+
|
1582 |
+
# PFVM doesn't support Linear with bias, so add bias manually afterwards.
|
1583 |
+
self.linear1 = nn.Linear(config.image_feature_size, config.hidden_size, bias=False)
|
1584 |
+
self.bias1 = Bias(config.hidden_size)
|
1585 |
+
self.act1 = nn.GELU()
|
1586 |
+
|
1587 |
+
self.linear2 = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
|
1588 |
+
self.bias2 = Bias(config.hidden_size)
|
1589 |
+
|
1590 |
+
def forward(
|
1591 |
+
self,
|
1592 |
+
hidden_states: torch.Tensor,
|
1593 |
+
) -> torch.Tensor:
|
1594 |
+
hidden_states = self.norm0(hidden_states)
|
1595 |
+
hidden_states = self.bias0(hidden_states)
|
1596 |
+
|
1597 |
+
hidden_states = self.linear1(hidden_states)
|
1598 |
+
hidden_states = self.bias1(hidden_states)
|
1599 |
+
hidden_states = self.act1(hidden_states)
|
1600 |
+
|
1601 |
+
hidden_states = self.linear2(hidden_states)
|
1602 |
+
hidden_states = self.bias2(hidden_states)
|
1603 |
+
|
1604 |
+
return hidden_states
|
1605 |
+
|
1606 |
+
|
1607 |
+
class Bias(nn.Module):
|
1608 |
+
def __init__(self, num_features: int) -> None:
|
1609 |
+
super().__init__()
|
1610 |
+
self._bias = nn.Parameter(torch.zeros((num_features,)))
|
1611 |
+
|
1612 |
+
def forward(
|
1613 |
+
self,
|
1614 |
+
x: torch.Tensor,
|
1615 |
+
) -> torch.Tensor:
|
1616 |
+
return x + self._bias
|
special_tokens_map.json
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": {
|
3 |
+
"content": "<|plamo:bos|>",
|
4 |
+
"lstrip": false,
|
5 |
+
"normalized": false,
|
6 |
+
"rstrip": false,
|
7 |
+
"single_word": false
|
8 |
+
},
|
9 |
+
"eos_token": {
|
10 |
+
"content": "<|plamo:eos|>",
|
11 |
+
"lstrip": false,
|
12 |
+
"normalized": false,
|
13 |
+
"rstrip": false,
|
14 |
+
"single_word": false
|
15 |
+
},
|
16 |
+
"pad_token": {
|
17 |
+
"content": "<|plamo:pad|>",
|
18 |
+
"lstrip": false,
|
19 |
+
"normalized": false,
|
20 |
+
"rstrip": false,
|
21 |
+
"single_word": false
|
22 |
+
},
|
23 |
+
"unk_token": {
|
24 |
+
"content": "<|plamo:unk|>",
|
25 |
+
"lstrip": false,
|
26 |
+
"normalized": false,
|
27 |
+
"rstrip": false,
|
28 |
+
"single_word": false
|
29 |
+
}
|
30 |
+
}
|
tokenization_plamo.py
ADDED
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import math
|
3 |
+
import os
|
4 |
+
from shutil import copyfile
|
5 |
+
from typing import Any, Optional, Tuple
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
# NOTE: numba does not support type hints for njit: https://github.com/python/mypy/issues/16149
|
10 |
+
from numba import njit # type: ignore[attr-defined]
|
11 |
+
from numba.core import types
|
12 |
+
from numba.typed import Dict, List
|
13 |
+
from transformers.tokenization_utils import PreTrainedTokenizer
|
14 |
+
from transformers.utils import logging
|
15 |
+
|
16 |
+
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.jsonl"}
|
17 |
+
logger = logging.get_logger(__name__)
|
18 |
+
|
19 |
+
INVALID_SCORE = -20000000
|
20 |
+
UNKNOWN_SCORE = -10000000
|
21 |
+
|
22 |
+
TABLE_PIECE_LENGTH = 0
|
23 |
+
TABLE_TOKEN_ID = 1
|
24 |
+
TABLE_SCORE = 2
|
25 |
+
TABLE_PIECE_ID = 3
|
26 |
+
|
27 |
+
PATH_TOKEN_LENGTH = 0
|
28 |
+
PATH_TOKEN_ID = 1
|
29 |
+
PATH_NUM_TOKENS = 2
|
30 |
+
|
31 |
+
|
32 |
+
class AhoCorasick:
|
33 |
+
def __init__(self) -> None:
|
34 |
+
# List of tokens in the vocabulary.
|
35 |
+
self._tokens: list[str]
|
36 |
+
|
37 |
+
# A mapping from a byte code point to a token ID, used for byte fallback.
|
38 |
+
self._bytes: np.ndarray
|
39 |
+
|
40 |
+
# A mapping from a suffix's piece code to a suffix ID.
|
41 |
+
#
|
42 |
+
# Typically, the Aho-Corasick algorithm builds a Trie and adds suffix links between nodes
|
43 |
+
# of the Trie. In this implementation, a suffix ID corresponds to a node in the trie, and
|
44 |
+
# a piece code to an edge (in other words, a pair of a node and the next character).
|
45 |
+
#
|
46 |
+
# A piece code is a 64-bit integer:
|
47 |
+
# - The upper 32 bits store the Unicode code point of the first character.
|
48 |
+
# - The lower 32 bits store the suffix ID of the remaining suffix.
|
49 |
+
#
|
50 |
+
# A suffix ID is an integer indicating the starting position in the _table.
|
51 |
+
self._to_suffix_id: Dict[types.int64, types.int32]
|
52 |
+
|
53 |
+
# Flattened table representing the Trie structure for the Aho-Corasick algorithm.
|
54 |
+
# It stores information including scores for each piece (prefix) within each suffix.
|
55 |
+
# It is flattened for memory efficiency and performance. Suffixes are stored in
|
56 |
+
# lexicographical order of their reversed strings, which improves memory access locality
|
57 |
+
# when exploring new characters starting from the string's end. Pieces within a suffix are
|
58 |
+
# stored in the decreasing order of their lengths.
|
59 |
+
#
|
60 |
+
# Each piece (a prefix fo the suffix) contains four pieces of information:
|
61 |
+
# - TABLE_PIECE_LENGTH: Length of the piece.
|
62 |
+
# - TABLE_TOKEN_ID: Token ID (or -1 if the piece is not a valid token).
|
63 |
+
# - TABLE_SCORE: Score (or INVALID_SCORE if the piece is not a valid token).
|
64 |
+
# - TABLE_PIECE_ID: Piece ID of the suffix.
|
65 |
+
#
|
66 |
+
# Each suffix also includes a sentinel row with a length of 1, a score of UNKNOWN_SCORE,
|
67 |
+
# and a token ID of -1. Sentinel rows are identified by the score being UNKNOWN_SCORE.
|
68 |
+
self._table: np.ndarray
|
69 |
+
|
70 |
+
def build(self, vocab: list[Any]) -> None:
|
71 |
+
self._bytes = np.zeros(256, dtype=np.int32)
|
72 |
+
self._to_suffix_id = Dict.empty(key_type=types.int64, value_type=types.int32)
|
73 |
+
|
74 |
+
# Build suffix_to_score and token_to_token_id.
|
75 |
+
# The suffix_to_score dictionary maps a suffix to its score. It also includes all suffixes
|
76 |
+
# of the token for the Trie structure for the Aho-Corasick algorithm. If a suffix is not a
|
77 |
+
# valid token, its score is set to math.nan.
|
78 |
+
# The token_to_token_id dictionary maps a token to its token ID.
|
79 |
+
suffix_to_score: dict[str, float] = {}
|
80 |
+
token_to_token_id: dict[str, int] = {}
|
81 |
+
self._tokens = []
|
82 |
+
for token_id, row in enumerate(vocab):
|
83 |
+
assert isinstance(row[0], str), row
|
84 |
+
assert isinstance(row[1], (int, float)), row
|
85 |
+
|
86 |
+
token = str(row[0])
|
87 |
+
self._tokens.append(token)
|
88 |
+
token_to_token_id[token] = token_id
|
89 |
+
|
90 |
+
# Special handling for byte tokens.
|
91 |
+
if len(row) > 2 and row[2] == "BYTE":
|
92 |
+
assert len(token) == 6 and token.startswith("<0x") and token.endswith(">"), row[0]
|
93 |
+
self._bytes[int(row[0][3:5], 16)] = token_id
|
94 |
+
continue
|
95 |
+
|
96 |
+
suffix_to_score[token] = float(row[1])
|
97 |
+
# Ensure that all suffixes are included in suffix_to_score.
|
98 |
+
for i in range(1, len(token)):
|
99 |
+
suffix_to_score[token[i:]] = suffix_to_score.get(token[i:], math.nan)
|
100 |
+
|
101 |
+
# Ensure all byte tokens are set.
|
102 |
+
for i in range(256):
|
103 |
+
assert self._bytes[i] != 0, f"Byte token for <0x{i:02X}> is not set."
|
104 |
+
|
105 |
+
# List suffixes in lexicographical order of their reversed strings.
|
106 |
+
suffixes = list(suffix_to_score.keys())
|
107 |
+
suffixes.append("")
|
108 |
+
suffixes.sort(key=lambda x: x[::-1])
|
109 |
+
|
110 |
+
# Build suffix_to_id, which is a mapping from a suffix to a suffix ID, and _to_suffix_id,
|
111 |
+
# which is a mapping from a piece code to a suffix ID.
|
112 |
+
suffix_to_id: dict[str, int] = {}
|
113 |
+
num_pieces = 0
|
114 |
+
for s in suffixes:
|
115 |
+
suffix_to_id[s] = num_pieces
|
116 |
+
if s != "":
|
117 |
+
self._to_suffix_id[ord(s[0]) << 32 | suffix_to_id[s[1:]]] = np.int32(num_pieces)
|
118 |
+
num_pieces += 1 + sum(s[:i] in suffix_to_score for i in range(1, len(s) + 1))
|
119 |
+
assert suffix_to_id[""] == 0, suffix_to_id[""]
|
120 |
+
|
121 |
+
# Build _table, which is a flattened table representing the Trie structure for the Aho-Corasick.
|
122 |
+
self._table = np.zeros((num_pieces, 4), dtype=np.int32)
|
123 |
+
i = 0
|
124 |
+
for suffix in suffixes:
|
125 |
+
# Add all prefixes of the suffix to the table.
|
126 |
+
for piece_length in range(len(suffix), 0, -1):
|
127 |
+
piece = suffix[:piece_length]
|
128 |
+
score = suffix_to_score.get(piece, None)
|
129 |
+
if score is None:
|
130 |
+
continue
|
131 |
+
self._table[i, TABLE_PIECE_LENGTH] = piece_length
|
132 |
+
self._table[i, TABLE_TOKEN_ID] = token_to_token_id.get(piece, -1)
|
133 |
+
self._table[i, TABLE_SCORE] = round(score * 1e4) if math.isfinite(score) else INVALID_SCORE
|
134 |
+
self._table[i, TABLE_PIECE_ID] = suffix_to_id[piece]
|
135 |
+
i += 1
|
136 |
+
|
137 |
+
# Add a sentinel row.
|
138 |
+
self._table[i, TABLE_PIECE_LENGTH] = 1
|
139 |
+
self._table[i, TABLE_TOKEN_ID] = -1
|
140 |
+
self._table[i, TABLE_SCORE] = UNKNOWN_SCORE
|
141 |
+
i += 1
|
142 |
+
assert i == num_pieces, (i, num_pieces)
|
143 |
+
|
144 |
+
@staticmethod
|
145 |
+
@njit
|
146 |
+
def _encode(
|
147 |
+
to_suffix_id: Dict[types.int64, types.int32],
|
148 |
+
table: np.ndarray,
|
149 |
+
bytes: np.ndarray,
|
150 |
+
data: np.ndarray,
|
151 |
+
) -> np.ndarray:
|
152 |
+
# Initialize scores array with a high value and set the score at the end to 0.
|
153 |
+
# This array keeps track of the minimum cost (best score) to encode from each position to the end.
|
154 |
+
scores = np.full((len(data) + 1,), 2**60, dtype=np.int64)
|
155 |
+
scores[-1] = 0
|
156 |
+
|
157 |
+
# Path array to store the best path information.
|
158 |
+
# The path array keeps track of token length, token ID, and number of tokens needed to encode.
|
159 |
+
path = np.zeros((len(data) + 1, 3), dtype=np.int32)
|
160 |
+
|
161 |
+
# Initialize suffix_id to 0, which represents the root of the Trie.
|
162 |
+
suffix_id = 0
|
163 |
+
|
164 |
+
# Process the input data from the end to the beginning.
|
165 |
+
for i in range(len(data) - 1, -1, -1):
|
166 |
+
c = data[i]
|
167 |
+
|
168 |
+
# Find the next suffix ID by iterating the suffix IDs of prefixes of the current suffix.
|
169 |
+
# NOTE: If no suffix ID is found, suffix_id will be set to 0.
|
170 |
+
for p in range(suffix_id, len(table)):
|
171 |
+
suffix_id = to_suffix_id.get(c << 32 | table[p, TABLE_PIECE_ID], np.int32(0))
|
172 |
+
# If a next suffix ID is found or a sentinel row is reached, break the loop.
|
173 |
+
if suffix_id > 0 or table[p, TABLE_SCORE] == UNKNOWN_SCORE:
|
174 |
+
break
|
175 |
+
|
176 |
+
# Update the best path to the current position. If multiple paths have the same score,
|
177 |
+
# this chooses the longest prefix as the best path (table is sorted in the decreasing
|
178 |
+
# order of piece length).
|
179 |
+
for p in range(suffix_id, len(table)):
|
180 |
+
score = table[p, TABLE_SCORE]
|
181 |
+
if score > INVALID_SCORE:
|
182 |
+
piece_length = table[p, TABLE_PIECE_LENGTH]
|
183 |
+
s = scores[i + piece_length] - score
|
184 |
+
if s < scores[i]:
|
185 |
+
scores[i] = s
|
186 |
+
path[i, PATH_TOKEN_LENGTH] = piece_length
|
187 |
+
path[i, PATH_TOKEN_ID] = table[p, TABLE_TOKEN_ID]
|
188 |
+
path[i, PATH_NUM_TOKENS] = path[i + piece_length, PATH_NUM_TOKENS] + 1
|
189 |
+
if score == UNKNOWN_SCORE:
|
190 |
+
# Add number of bytes to represent `c` in UTF-8 (minus 1; 1 is already
|
191 |
+
# added above).
|
192 |
+
path[i, PATH_NUM_TOKENS] += (c >= 0x80) + (c >= 0x800) + (c >= 0x10000)
|
193 |
+
|
194 |
+
# If it reaches a sentinel row, break the loop.
|
195 |
+
if score == UNKNOWN_SCORE:
|
196 |
+
break
|
197 |
+
|
198 |
+
# Decode the best path from the beginning to get the token IDs.
|
199 |
+
pos = 0
|
200 |
+
token_ids = np.zeros(path[0, PATH_NUM_TOKENS], dtype=np.int32)
|
201 |
+
token_pos = 0
|
202 |
+
while pos < len(data):
|
203 |
+
if path[pos, PATH_TOKEN_ID] >= 0:
|
204 |
+
token_ids[token_pos] = path[pos, PATH_TOKEN_ID]
|
205 |
+
token_pos += 1
|
206 |
+
else:
|
207 |
+
# Fall back to byte tokens.
|
208 |
+
c = data[pos]
|
209 |
+
s = 1 + (c >= 0x80) + (c >= 0x800) + (c >= 0x10000)
|
210 |
+
# Add byte tokens representing UTF-8 bytes.
|
211 |
+
for i in range(s):
|
212 |
+
b = c if s == 1 else (0xF00 >> s) & 0xFF if i == 0 else 0x80
|
213 |
+
token_ids[token_pos] = bytes[b | ((c >> (s - i - 1) * 6) & 0x3F)]
|
214 |
+
token_pos += 1
|
215 |
+
|
216 |
+
# Ensure that pos should increase by at least 1.
|
217 |
+
assert path[pos, PATH_TOKEN_LENGTH] > 0, (pos, path[pos])
|
218 |
+
pos += path[pos, PATH_TOKEN_LENGTH]
|
219 |
+
|
220 |
+
return token_ids
|
221 |
+
|
222 |
+
def encode(self, data: str) -> np.ndarray:
|
223 |
+
"""Encodes a string into a sequence of token IDs."""
|
224 |
+
return np.asarray(
|
225 |
+
self._encode(
|
226 |
+
self._to_suffix_id,
|
227 |
+
self._table,
|
228 |
+
self._bytes,
|
229 |
+
# Convert a string into a numpy array of Unicode code points.
|
230 |
+
# NOTE: This skips UTF-32 BOM.
|
231 |
+
np.frombuffer(data.encode("utf-32"), dtype=np.int32)[1:],
|
232 |
+
)
|
233 |
+
)
|
234 |
+
|
235 |
+
def encode_as_tokens(self, data: str) -> list[str]:
|
236 |
+
"""Encodes a string into a sequence of tokens."""
|
237 |
+
return [self._tokens[token_id] for token_id in self.encode(data)]
|
238 |
+
|
239 |
+
|
240 |
+
class PlamoTokenizer(PreTrainedTokenizer): # type: ignore
|
241 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
242 |
+
model_input_names = ["input_ids", "attention_mask"]
|
243 |
+
|
244 |
+
_save_files = [
|
245 |
+
"special_tokens_map.json",
|
246 |
+
"tokenization_plamo.py",
|
247 |
+
"tokenizer.jsonl",
|
248 |
+
"tokenizer_config.json",
|
249 |
+
]
|
250 |
+
|
251 |
+
def __init__(
|
252 |
+
self,
|
253 |
+
vocab_file: str,
|
254 |
+
unk_token: str = "<|plamo:unk|>",
|
255 |
+
bos_token: str = "<|plamo:bos|>",
|
256 |
+
eos_token: str = "<|plamo:eos|>",
|
257 |
+
pad_token: str = "<|plamo:pad|>",
|
258 |
+
cls_token: Optional[str] = None,
|
259 |
+
sep_token: Optional[str] = None,
|
260 |
+
mask_token: Optional[str] = None,
|
261 |
+
clean_up_tokenization_spaces: bool = False,
|
262 |
+
**kwargs: Any,
|
263 |
+
) -> None:
|
264 |
+
"""Tokenizer for PLaMo.
|
265 |
+
|
266 |
+
Args:
|
267 |
+
vocab_file (str): Vocabrary file path.
|
268 |
+
unk_token (str): Unknown token.
|
269 |
+
bos_token (str): Beginning of sentence token.
|
270 |
+
eos_token (str): End of sentence token.
|
271 |
+
pad_token (str): Padding token.
|
272 |
+
cls_token (str):
|
273 |
+
Classification token, to extract a summary of an input sequence leveraging self-attention along the
|
274 |
+
full depth of the model.
|
275 |
+
sep_token (str): Separation token, to separate context and query in an input sequence.
|
276 |
+
mask_token (str): Mask token, to use when training a model with masked-language modeling.
|
277 |
+
clean_up_tokenization_spaces (bool): Whether or not to clean up the tokenization spaces.
|
278 |
+
num_threads (int):
|
279 |
+
Number of threads. This value will be ignored if one of `PLAMO_TOKENIZER_NUM_THREADS` or
|
280 |
+
`RAYON_NUM_THREADS` is set as an environment variable.
|
281 |
+
"""
|
282 |
+
if "add_bos_token" not in kwargs:
|
283 |
+
kwargs["add_bos_token"] = False
|
284 |
+
if "add_eos_token" not in kwargs:
|
285 |
+
kwargs["add_eos_token"] = False
|
286 |
+
self.data: list[Any] = [json.loads(line) for line in open(vocab_file, "r", encoding="utf-8")]
|
287 |
+
self.vocab: dict[str, int] = {v[0]: i for i, v in enumerate(self.data)}
|
288 |
+
self.aho_corasick = AhoCorasick()
|
289 |
+
self.aho_corasick.build(self.data)
|
290 |
+
self.vocab_file = vocab_file
|
291 |
+
self.add_bos_token = kwargs["add_bos_token"]
|
292 |
+
self.add_eos_token = kwargs["add_eos_token"]
|
293 |
+
|
294 |
+
super().__init__(
|
295 |
+
vocab_file=vocab_file,
|
296 |
+
unk_token=unk_token,
|
297 |
+
bos_token=bos_token,
|
298 |
+
eos_token=eos_token,
|
299 |
+
pad_token=pad_token,
|
300 |
+
cls_token=cls_token,
|
301 |
+
sep_token=sep_token,
|
302 |
+
mask_token=mask_token,
|
303 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
304 |
+
**kwargs,
|
305 |
+
)
|
306 |
+
|
307 |
+
# the functions below are copied from hf transformers LlamaTokenizer's implementation to fix the behaviour of the tokenizer
|
308 |
+
# https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/models/llama/tokenization_llama.py
|
309 |
+
|
310 |
+
def __getstate__(self) -> dict[str, Any]:
|
311 |
+
state = self.__dict__.copy()
|
312 |
+
state["aho_corasick"] = None
|
313 |
+
return state
|
314 |
+
|
315 |
+
def __setstate__(self, d: dict[str, Any]) -> None:
|
316 |
+
self.__dict__ = d
|
317 |
+
self.aho_corasick = AhoCorasick()
|
318 |
+
self.aho_corasick.build(self.data)
|
319 |
+
|
320 |
+
@property
|
321 |
+
def vocab_size(self) -> Any:
|
322 |
+
"""Returns vocab size"""
|
323 |
+
return len(self.data)
|
324 |
+
|
325 |
+
def token_to_score(self, token: str) -> Optional[float]:
|
326 |
+
"""Returns score of the token"""
|
327 |
+
token_id = self.vocab.get(token, None)
|
328 |
+
return None if token_id is None else self.data[token_id][1]
|
329 |
+
|
330 |
+
def get_vocab(self) -> dict[str, int]:
|
331 |
+
"""Returns vocab as a dict"""
|
332 |
+
vocab = self.vocab.copy()
|
333 |
+
vocab.update(self.added_tokens_encoder)
|
334 |
+
return vocab
|
335 |
+
|
336 |
+
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
337 |
+
"""Converts a sequence of tokens (string) in a single string."""
|
338 |
+
return b"".join(
|
339 |
+
[bytes([int(t[3:5], 16)]) if t.startswith("<0x") else t.encode("utf-8") for t in tokens]
|
340 |
+
).decode("utf-8", errors="replace")
|
341 |
+
|
342 |
+
def _tokenize(self, text: str) -> Any:
|
343 |
+
"""Returns a tokenized string."""
|
344 |
+
return self.aho_corasick.encode_as_tokens(text)
|
345 |
+
|
346 |
+
def _convert_token_to_id(self, token: str) -> Any:
|
347 |
+
"""Converts a token (str) in an id using the vocab."""
|
348 |
+
return self.vocab.get(token, 0)
|
349 |
+
|
350 |
+
def _convert_id_to_token(self, index: int) -> Any:
|
351 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
352 |
+
return self.data[index][0]
|
353 |
+
|
354 |
+
def build_inputs_with_special_tokens(
|
355 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
356 |
+
) -> List[int]:
|
357 |
+
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
358 |
+
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
359 |
+
|
360 |
+
output = bos_token_id + token_ids_0 + eos_token_id
|
361 |
+
|
362 |
+
if token_ids_1 is not None:
|
363 |
+
output = output + bos_token_id + token_ids_1 + eos_token_id
|
364 |
+
|
365 |
+
return output
|
366 |
+
|
367 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
368 |
+
"""
|
369 |
+
Save the vocabulary and special tokens file to a directory.
|
370 |
+
|
371 |
+
Args:
|
372 |
+
save_directory (`str`):
|
373 |
+
The directory in which to save the vocabulary.
|
374 |
+
|
375 |
+
Returns:
|
376 |
+
`Tuple(str)`: Paths to the files saved.
|
377 |
+
"""
|
378 |
+
if not os.path.isdir(save_directory):
|
379 |
+
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
380 |
+
return ("",)
|
381 |
+
out_vocab_file = os.path.join(
|
382 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
383 |
+
)
|
384 |
+
|
385 |
+
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
|
386 |
+
copyfile(self.vocab_file, out_vocab_file)
|
387 |
+
elif not os.path.isfile(self.vocab_file):
|
388 |
+
with open(out_vocab_file, "w") as f:
|
389 |
+
for token in self.data:
|
390 |
+
print(json.dumps(token, ensure_ascii=False), file=f)
|
391 |
+
|
392 |
+
return (out_vocab_file,)
|
tokenizer.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenizer_config.json
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_bos_token": true,
|
3 |
+
"add_eos_token": false,
|
4 |
+
"added_tokens_decoder": {
|
5 |
+
"0": {
|
6 |
+
"content": "<|plamo:unk|>",
|
7 |
+
"lstrip": false,
|
8 |
+
"normalized": false,
|
9 |
+
"rstrip": false,
|
10 |
+
"single_word": false,
|
11 |
+
"special": true
|
12 |
+
},
|
13 |
+
"1": {
|
14 |
+
"content": "<|plamo:bos|>",
|
15 |
+
"lstrip": false,
|
16 |
+
"normalized": false,
|
17 |
+
"rstrip": false,
|
18 |
+
"single_word": false,
|
19 |
+
"special": true
|
20 |
+
},
|
21 |
+
"2": {
|
22 |
+
"content": "<|plamo:eos|>",
|
23 |
+
"lstrip": false,
|
24 |
+
"normalized": false,
|
25 |
+
"rstrip": false,
|
26 |
+
"single_word": false,
|
27 |
+
"special": true
|
28 |
+
},
|
29 |
+
"3": {
|
30 |
+
"content": "<|plamo:pad|>",
|
31 |
+
"lstrip": false,
|
32 |
+
"normalized": false,
|
33 |
+
"rstrip": false,
|
34 |
+
"single_word": false,
|
35 |
+
"special": true
|
36 |
+
}
|
37 |
+
},
|
38 |
+
"auto_map": {
|
39 |
+
"AutoTokenizer": [
|
40 |
+
"tokenization_plamo.PlamoTokenizer",
|
41 |
+
null
|
42 |
+
]
|
43 |
+
},
|
44 |
+
"bos_token": "<|plamo:bos|>",
|
45 |
+
"clean_up_tokenization_spaces": false,
|
46 |
+
"cls_token": null,
|
47 |
+
"eos_token": "<|plamo:eos|>",
|
48 |
+
"local_file_only": true,
|
49 |
+
"mask_token": null,
|
50 |
+
"model_max_length": 1000000000000000019884624838656,
|
51 |
+
"pad_token": "<|plamo:pad|>",
|
52 |
+
"sep_token": null,
|
53 |
+
"tokenizer_class": "PlamoTokenizer",
|
54 |
+
"unk_token": "<|plamo:unk|>"
|
55 |
+
}
|