mitmul committed (verified)
Commit ce95292 · 1 Parent(s): 103a96d

Add files using upload-large-folder tool

README.md ADDED
@@ -0,0 +1,39 @@
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ - ja
6
+ pipeline_tag: text-generation
7
+ library_name: transformers
8
+ base_model: pfnet/plamo-2-1b
9
+ tags:
10
+ - mlx
11
+ ---
12
+
13
+ # mlx-community/plamo-2-1b-bf16
14
+
15
+ The model [mlx-community/plamo-2-1b-bf16](https://huggingface.co/mlx-community/plamo-2-1b-bf16) was
16
+ converted to MLX format from [pfnet/plamo-2-1b](https://huggingface.co/pfnet/plamo-2-1b)
17
+ using mlx-lm version **0.21.5**.
18
+
19
+ ## Use with mlx
20
+
21
+ ```bash
22
+ pip install mlx-lm
23
+ ```
24
+
25
+ ```python
26
+ from mlx_lm import load, generate
27
+
28
+ model, tokenizer = load("mlx-community/plamo-2-1b-bf16")
29
+
30
+ prompt = "hello"
31
+
32
+ if tokenizer.chat_template is not None:
33
+ messages = [{"role": "user", "content": prompt}]
34
+ prompt = tokenizer.apply_chat_template(
35
+ messages, add_generation_prompt=True
36
+ )
37
+
38
+ response = generate(model, tokenizer, prompt=prompt, verbose=True)
39
+ ```
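
The snippet above prints the generation with default settings. `generate` also accepts a token limit via `max_tokens` (assumed here from the mlx-lm API; your installed version may differ), and since the base model covers English and Japanese, a Japanese prompt works the same way. A minimal sketch:

```python
# Sketch: cap the number of generated tokens and try a Japanese prompt.
# `max_tokens` is assumed from the mlx-lm `generate` API; check your installed version.
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/plamo-2-1b-bf16")
response = generate(
    model,
    tokenizer,
    prompt="日本で一番高い山は",  # "The highest mountain in Japan is"
    max_tokens=64,
    verbose=True,
)
```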
config.json ADDED
@@ -0,0 +1,49 @@
1
+ {
2
+ "architectures": [
3
+ "PlamoForCausalLM"
4
+ ],
5
+ "attention_window_size": 2048,
6
+ "auto_map": {
7
+ "AutoConfig": "modeling_plamo.PlamoConfig",
8
+ "AutoModelForCausalLM": "modeling_plamo.PlamoForCausalLM"
9
+ },
10
+ "bos_token_id": 1,
11
+ "capacity_factor": 1.0,
12
+ "eos_token_id": 2,
13
+ "eval_attention_n_bit": null,
14
+ "eval_mlp_n_bit": null,
15
+ "expert_dropout": 0.0,
16
+ "fp8_accum_dtype": "bfloat16",
17
+ "group_size": 1024,
18
+ "hidden_size": 2048,
19
+ "hidden_size_per_head": 128,
20
+ "image_feature_size": null,
21
+ "image_proj_type": "linear",
22
+ "image_token_id": null,
23
+ "intermediate_size": 8192,
24
+ "k_expert": null,
25
+ "linear_type": "fp8",
26
+ "mamba_chunk_size": 256,
27
+ "mamba_d_conv": 4,
28
+ "mamba_d_state": 64,
29
+ "mamba_enabled": true,
30
+ "mamba_num_heads": 32,
31
+ "mamba_step": 2,
32
+ "max_position_embeddings": 10485760,
33
+ "model_type": "plamo2",
34
+ "n_expert": null,
35
+ "num_attention_heads": 16,
36
+ "num_hidden_layers": 16,
37
+ "num_key_value_heads": 1,
38
+ "rms_norm_eps": 1e-06,
39
+ "shared_intermediate_size": null,
40
+ "sliding_window": 2048,
41
+ "sparse_intermediate_size": null,
42
+ "sparse_step": null,
43
+ "tokenizer_class": "PlamoTokenizer",
44
+ "torch_dtype": "float32",
45
+ "transformers_version": "4.44.2",
46
+ "use_cache": true,
47
+ "use_predefined_initial_state": false,
48
+ "vocab_size": 100000
49
+ }
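
The config above describes a hybrid model: Mamba mixers (`mamba_enabled`, `mamba_step: 2`) interleaved with sliding-window attention (`attention_window_size`/`sliding_window`: 2048) across 16 layers. A minimal sketch for inspecting these fields without downloading the weights, assuming `huggingface_hub` is installed:

```python
# Sketch: read config.json from the Hub and print the architecture parameters shown above.
import json

from huggingface_hub import hf_hub_download

cfg_path = hf_hub_download("mlx-community/plamo-2-1b-bf16", "config.json")
with open(cfg_path) as f:
    cfg = json.load(f)

# 16 layers, Mamba/attention mixers alternating every `mamba_step` layers,
# attention limited to a 2048-token sliding window.
print(cfg["num_hidden_layers"], cfg["mamba_step"], cfg["attention_window_size"])
```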
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a3a5eacef896d4ebe5ce590df7fbfc8dc2c4f3dc4886e2ae01e7a609dd7bd827
3
+ size 2582909060
model.safetensors.index.json ADDED
@@ -0,0 +1,225 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 2582883840
4
+ },
5
+ "weight_map": {
6
+ "model.embed_tokens.weight": "model.safetensors",
7
+ "model.layers.layers.0.mixer.A_log": "model.safetensors",
8
+ "model.layers.layers.0.mixer.B_norm_weight": "model.safetensors",
9
+ "model.layers.layers.0.mixer.C_norm_weight": "model.safetensors",
10
+ "model.layers.layers.0.mixer.D": "model.safetensors",
11
+ "model.layers.layers.0.mixer.bcdt_proj.weight": "model.safetensors",
12
+ "model.layers.layers.0.mixer.conv1d.weight": "model.safetensors",
13
+ "model.layers.layers.0.mixer.dt_bias": "model.safetensors",
14
+ "model.layers.layers.0.mixer.dt_norm_weight": "model.safetensors",
15
+ "model.layers.layers.0.mixer.dt_proj.weight": "model.safetensors",
16
+ "model.layers.layers.0.mixer.in_proj.weight": "model.safetensors",
17
+ "model.layers.layers.0.mixer.out_proj.weight": "model.safetensors",
18
+ "model.layers.layers.0.mlp.down_proj.weight": "model.safetensors",
19
+ "model.layers.layers.0.mlp.gate_up_proj.weight": "model.safetensors",
20
+ "model.layers.layers.0.post_mixer_norm.weight": "model.safetensors",
21
+ "model.layers.layers.0.post_mlp_norm.weight": "model.safetensors",
22
+ "model.layers.layers.0.pre_mixer_norm.weight": "model.safetensors",
23
+ "model.layers.layers.0.pre_mlp_norm.weight": "model.safetensors",
24
+ "model.layers.layers.1.mixer.k_weight": "model.safetensors",
25
+ "model.layers.layers.1.mixer.o_proj.weight": "model.safetensors",
26
+ "model.layers.layers.1.mixer.q_weight": "model.safetensors",
27
+ "model.layers.layers.1.mixer.qkv_proj.weight": "model.safetensors",
28
+ "model.layers.layers.1.mlp.down_proj.weight": "model.safetensors",
29
+ "model.layers.layers.1.mlp.gate_up_proj.weight": "model.safetensors",
30
+ "model.layers.layers.1.post_mixer_norm.weight": "model.safetensors",
31
+ "model.layers.layers.1.post_mlp_norm.weight": "model.safetensors",
32
+ "model.layers.layers.1.pre_mixer_norm.weight": "model.safetensors",
33
+ "model.layers.layers.1.pre_mlp_norm.weight": "model.safetensors",
34
+ "model.layers.layers.10.mixer.A_log": "model.safetensors",
35
+ "model.layers.layers.10.mixer.B_norm_weight": "model.safetensors",
36
+ "model.layers.layers.10.mixer.C_norm_weight": "model.safetensors",
37
+ "model.layers.layers.10.mixer.D": "model.safetensors",
38
+ "model.layers.layers.10.mixer.bcdt_proj.weight": "model.safetensors",
39
+ "model.layers.layers.10.mixer.conv1d.weight": "model.safetensors",
40
+ "model.layers.layers.10.mixer.dt_bias": "model.safetensors",
41
+ "model.layers.layers.10.mixer.dt_norm_weight": "model.safetensors",
42
+ "model.layers.layers.10.mixer.dt_proj.weight": "model.safetensors",
43
+ "model.layers.layers.10.mixer.in_proj.weight": "model.safetensors",
44
+ "model.layers.layers.10.mixer.out_proj.weight": "model.safetensors",
45
+ "model.layers.layers.10.mlp.down_proj.weight": "model.safetensors",
46
+ "model.layers.layers.10.mlp.gate_up_proj.weight": "model.safetensors",
47
+ "model.layers.layers.10.post_mixer_norm.weight": "model.safetensors",
48
+ "model.layers.layers.10.post_mlp_norm.weight": "model.safetensors",
49
+ "model.layers.layers.10.pre_mixer_norm.weight": "model.safetensors",
50
+ "model.layers.layers.10.pre_mlp_norm.weight": "model.safetensors",
51
+ "model.layers.layers.11.mixer.k_weight": "model.safetensors",
52
+ "model.layers.layers.11.mixer.o_proj.weight": "model.safetensors",
53
+ "model.layers.layers.11.mixer.q_weight": "model.safetensors",
54
+ "model.layers.layers.11.mixer.qkv_proj.weight": "model.safetensors",
55
+ "model.layers.layers.11.mlp.down_proj.weight": "model.safetensors",
56
+ "model.layers.layers.11.mlp.gate_up_proj.weight": "model.safetensors",
57
+ "model.layers.layers.11.post_mixer_norm.weight": "model.safetensors",
58
+ "model.layers.layers.11.post_mlp_norm.weight": "model.safetensors",
59
+ "model.layers.layers.11.pre_mixer_norm.weight": "model.safetensors",
60
+ "model.layers.layers.11.pre_mlp_norm.weight": "model.safetensors",
61
+ "model.layers.layers.12.mixer.A_log": "model.safetensors",
62
+ "model.layers.layers.12.mixer.B_norm_weight": "model.safetensors",
63
+ "model.layers.layers.12.mixer.C_norm_weight": "model.safetensors",
64
+ "model.layers.layers.12.mixer.D": "model.safetensors",
65
+ "model.layers.layers.12.mixer.bcdt_proj.weight": "model.safetensors",
66
+ "model.layers.layers.12.mixer.conv1d.weight": "model.safetensors",
67
+ "model.layers.layers.12.mixer.dt_bias": "model.safetensors",
68
+ "model.layers.layers.12.mixer.dt_norm_weight": "model.safetensors",
69
+ "model.layers.layers.12.mixer.dt_proj.weight": "model.safetensors",
70
+ "model.layers.layers.12.mixer.in_proj.weight": "model.safetensors",
71
+ "model.layers.layers.12.mixer.out_proj.weight": "model.safetensors",
72
+ "model.layers.layers.12.mlp.down_proj.weight": "model.safetensors",
73
+ "model.layers.layers.12.mlp.gate_up_proj.weight": "model.safetensors",
74
+ "model.layers.layers.12.post_mixer_norm.weight": "model.safetensors",
75
+ "model.layers.layers.12.post_mlp_norm.weight": "model.safetensors",
76
+ "model.layers.layers.12.pre_mixer_norm.weight": "model.safetensors",
77
+ "model.layers.layers.12.pre_mlp_norm.weight": "model.safetensors",
78
+ "model.layers.layers.13.mixer.k_weight": "model.safetensors",
79
+ "model.layers.layers.13.mixer.o_proj.weight": "model.safetensors",
80
+ "model.layers.layers.13.mixer.q_weight": "model.safetensors",
81
+ "model.layers.layers.13.mixer.qkv_proj.weight": "model.safetensors",
82
+ "model.layers.layers.13.mlp.down_proj.weight": "model.safetensors",
83
+ "model.layers.layers.13.mlp.gate_up_proj.weight": "model.safetensors",
84
+ "model.layers.layers.13.post_mixer_norm.weight": "model.safetensors",
85
+ "model.layers.layers.13.post_mlp_norm.weight": "model.safetensors",
86
+ "model.layers.layers.13.pre_mixer_norm.weight": "model.safetensors",
87
+ "model.layers.layers.13.pre_mlp_norm.weight": "model.safetensors",
88
+ "model.layers.layers.14.mixer.A_log": "model.safetensors",
89
+ "model.layers.layers.14.mixer.B_norm_weight": "model.safetensors",
90
+ "model.layers.layers.14.mixer.C_norm_weight": "model.safetensors",
91
+ "model.layers.layers.14.mixer.D": "model.safetensors",
92
+ "model.layers.layers.14.mixer.bcdt_proj.weight": "model.safetensors",
93
+ "model.layers.layers.14.mixer.conv1d.weight": "model.safetensors",
94
+ "model.layers.layers.14.mixer.dt_bias": "model.safetensors",
95
+ "model.layers.layers.14.mixer.dt_norm_weight": "model.safetensors",
96
+ "model.layers.layers.14.mixer.dt_proj.weight": "model.safetensors",
97
+ "model.layers.layers.14.mixer.in_proj.weight": "model.safetensors",
98
+ "model.layers.layers.14.mixer.out_proj.weight": "model.safetensors",
99
+ "model.layers.layers.14.mlp.down_proj.weight": "model.safetensors",
100
+ "model.layers.layers.14.mlp.gate_up_proj.weight": "model.safetensors",
101
+ "model.layers.layers.14.post_mixer_norm.weight": "model.safetensors",
102
+ "model.layers.layers.14.post_mlp_norm.weight": "model.safetensors",
103
+ "model.layers.layers.14.pre_mixer_norm.weight": "model.safetensors",
104
+ "model.layers.layers.14.pre_mlp_norm.weight": "model.safetensors",
105
+ "model.layers.layers.15.mixer.k_weight": "model.safetensors",
106
+ "model.layers.layers.15.mixer.o_proj.weight": "model.safetensors",
107
+ "model.layers.layers.15.mixer.q_weight": "model.safetensors",
108
+ "model.layers.layers.15.mixer.qkv_proj.weight": "model.safetensors",
109
+ "model.layers.layers.15.mlp.down_proj.weight": "model.safetensors",
110
+ "model.layers.layers.15.mlp.gate_up_proj.weight": "model.safetensors",
111
+ "model.layers.layers.15.post_mixer_norm.weight": "model.safetensors",
112
+ "model.layers.layers.15.post_mlp_norm.weight": "model.safetensors",
113
+ "model.layers.layers.15.pre_mixer_norm.weight": "model.safetensors",
114
+ "model.layers.layers.15.pre_mlp_norm.weight": "model.safetensors",
115
+ "model.layers.layers.2.mixer.A_log": "model.safetensors",
116
+ "model.layers.layers.2.mixer.B_norm_weight": "model.safetensors",
117
+ "model.layers.layers.2.mixer.C_norm_weight": "model.safetensors",
118
+ "model.layers.layers.2.mixer.D": "model.safetensors",
119
+ "model.layers.layers.2.mixer.bcdt_proj.weight": "model.safetensors",
120
+ "model.layers.layers.2.mixer.conv1d.weight": "model.safetensors",
121
+ "model.layers.layers.2.mixer.dt_bias": "model.safetensors",
122
+ "model.layers.layers.2.mixer.dt_norm_weight": "model.safetensors",
123
+ "model.layers.layers.2.mixer.dt_proj.weight": "model.safetensors",
124
+ "model.layers.layers.2.mixer.in_proj.weight": "model.safetensors",
125
+ "model.layers.layers.2.mixer.out_proj.weight": "model.safetensors",
126
+ "model.layers.layers.2.mlp.down_proj.weight": "model.safetensors",
127
+ "model.layers.layers.2.mlp.gate_up_proj.weight": "model.safetensors",
128
+ "model.layers.layers.2.post_mixer_norm.weight": "model.safetensors",
129
+ "model.layers.layers.2.post_mlp_norm.weight": "model.safetensors",
130
+ "model.layers.layers.2.pre_mixer_norm.weight": "model.safetensors",
131
+ "model.layers.layers.2.pre_mlp_norm.weight": "model.safetensors",
132
+ "model.layers.layers.3.mixer.k_weight": "model.safetensors",
133
+ "model.layers.layers.3.mixer.o_proj.weight": "model.safetensors",
134
+ "model.layers.layers.3.mixer.q_weight": "model.safetensors",
135
+ "model.layers.layers.3.mixer.qkv_proj.weight": "model.safetensors",
136
+ "model.layers.layers.3.mlp.down_proj.weight": "model.safetensors",
137
+ "model.layers.layers.3.mlp.gate_up_proj.weight": "model.safetensors",
138
+ "model.layers.layers.3.post_mixer_norm.weight": "model.safetensors",
139
+ "model.layers.layers.3.post_mlp_norm.weight": "model.safetensors",
140
+ "model.layers.layers.3.pre_mixer_norm.weight": "model.safetensors",
141
+ "model.layers.layers.3.pre_mlp_norm.weight": "model.safetensors",
142
+ "model.layers.layers.4.mixer.A_log": "model.safetensors",
143
+ "model.layers.layers.4.mixer.B_norm_weight": "model.safetensors",
144
+ "model.layers.layers.4.mixer.C_norm_weight": "model.safetensors",
145
+ "model.layers.layers.4.mixer.D": "model.safetensors",
146
+ "model.layers.layers.4.mixer.bcdt_proj.weight": "model.safetensors",
147
+ "model.layers.layers.4.mixer.conv1d.weight": "model.safetensors",
148
+ "model.layers.layers.4.mixer.dt_bias": "model.safetensors",
149
+ "model.layers.layers.4.mixer.dt_norm_weight": "model.safetensors",
150
+ "model.layers.layers.4.mixer.dt_proj.weight": "model.safetensors",
151
+ "model.layers.layers.4.mixer.in_proj.weight": "model.safetensors",
152
+ "model.layers.layers.4.mixer.out_proj.weight": "model.safetensors",
153
+ "model.layers.layers.4.mlp.down_proj.weight": "model.safetensors",
154
+ "model.layers.layers.4.mlp.gate_up_proj.weight": "model.safetensors",
155
+ "model.layers.layers.4.post_mixer_norm.weight": "model.safetensors",
156
+ "model.layers.layers.4.post_mlp_norm.weight": "model.safetensors",
157
+ "model.layers.layers.4.pre_mixer_norm.weight": "model.safetensors",
158
+ "model.layers.layers.4.pre_mlp_norm.weight": "model.safetensors",
159
+ "model.layers.layers.5.mixer.k_weight": "model.safetensors",
160
+ "model.layers.layers.5.mixer.o_proj.weight": "model.safetensors",
161
+ "model.layers.layers.5.mixer.q_weight": "model.safetensors",
162
+ "model.layers.layers.5.mixer.qkv_proj.weight": "model.safetensors",
163
+ "model.layers.layers.5.mlp.down_proj.weight": "model.safetensors",
164
+ "model.layers.layers.5.mlp.gate_up_proj.weight": "model.safetensors",
165
+ "model.layers.layers.5.post_mixer_norm.weight": "model.safetensors",
166
+ "model.layers.layers.5.post_mlp_norm.weight": "model.safetensors",
167
+ "model.layers.layers.5.pre_mixer_norm.weight": "model.safetensors",
168
+ "model.layers.layers.5.pre_mlp_norm.weight": "model.safetensors",
169
+ "model.layers.layers.6.mixer.A_log": "model.safetensors",
170
+ "model.layers.layers.6.mixer.B_norm_weight": "model.safetensors",
171
+ "model.layers.layers.6.mixer.C_norm_weight": "model.safetensors",
172
+ "model.layers.layers.6.mixer.D": "model.safetensors",
173
+ "model.layers.layers.6.mixer.bcdt_proj.weight": "model.safetensors",
174
+ "model.layers.layers.6.mixer.conv1d.weight": "model.safetensors",
175
+ "model.layers.layers.6.mixer.dt_bias": "model.safetensors",
176
+ "model.layers.layers.6.mixer.dt_norm_weight": "model.safetensors",
177
+ "model.layers.layers.6.mixer.dt_proj.weight": "model.safetensors",
178
+ "model.layers.layers.6.mixer.in_proj.weight": "model.safetensors",
179
+ "model.layers.layers.6.mixer.out_proj.weight": "model.safetensors",
180
+ "model.layers.layers.6.mlp.down_proj.weight": "model.safetensors",
181
+ "model.layers.layers.6.mlp.gate_up_proj.weight": "model.safetensors",
182
+ "model.layers.layers.6.post_mixer_norm.weight": "model.safetensors",
183
+ "model.layers.layers.6.post_mlp_norm.weight": "model.safetensors",
184
+ "model.layers.layers.6.pre_mixer_norm.weight": "model.safetensors",
185
+ "model.layers.layers.6.pre_mlp_norm.weight": "model.safetensors",
186
+ "model.layers.layers.7.mixer.k_weight": "model.safetensors",
187
+ "model.layers.layers.7.mixer.o_proj.weight": "model.safetensors",
188
+ "model.layers.layers.7.mixer.q_weight": "model.safetensors",
189
+ "model.layers.layers.7.mixer.qkv_proj.weight": "model.safetensors",
190
+ "model.layers.layers.7.mlp.down_proj.weight": "model.safetensors",
191
+ "model.layers.layers.7.mlp.gate_up_proj.weight": "model.safetensors",
192
+ "model.layers.layers.7.post_mixer_norm.weight": "model.safetensors",
193
+ "model.layers.layers.7.post_mlp_norm.weight": "model.safetensors",
194
+ "model.layers.layers.7.pre_mixer_norm.weight": "model.safetensors",
195
+ "model.layers.layers.7.pre_mlp_norm.weight": "model.safetensors",
196
+ "model.layers.layers.8.mixer.A_log": "model.safetensors",
197
+ "model.layers.layers.8.mixer.B_norm_weight": "model.safetensors",
198
+ "model.layers.layers.8.mixer.C_norm_weight": "model.safetensors",
199
+ "model.layers.layers.8.mixer.D": "model.safetensors",
200
+ "model.layers.layers.8.mixer.bcdt_proj.weight": "model.safetensors",
201
+ "model.layers.layers.8.mixer.conv1d.weight": "model.safetensors",
202
+ "model.layers.layers.8.mixer.dt_bias": "model.safetensors",
203
+ "model.layers.layers.8.mixer.dt_norm_weight": "model.safetensors",
204
+ "model.layers.layers.8.mixer.dt_proj.weight": "model.safetensors",
205
+ "model.layers.layers.8.mixer.in_proj.weight": "model.safetensors",
206
+ "model.layers.layers.8.mixer.out_proj.weight": "model.safetensors",
207
+ "model.layers.layers.8.mlp.down_proj.weight": "model.safetensors",
208
+ "model.layers.layers.8.mlp.gate_up_proj.weight": "model.safetensors",
209
+ "model.layers.layers.8.post_mixer_norm.weight": "model.safetensors",
210
+ "model.layers.layers.8.post_mlp_norm.weight": "model.safetensors",
211
+ "model.layers.layers.8.pre_mixer_norm.weight": "model.safetensors",
212
+ "model.layers.layers.8.pre_mlp_norm.weight": "model.safetensors",
213
+ "model.layers.layers.9.mixer.k_weight": "model.safetensors",
214
+ "model.layers.layers.9.mixer.o_proj.weight": "model.safetensors",
215
+ "model.layers.layers.9.mixer.q_weight": "model.safetensors",
216
+ "model.layers.layers.9.mixer.qkv_proj.weight": "model.safetensors",
217
+ "model.layers.layers.9.mlp.down_proj.weight": "model.safetensors",
218
+ "model.layers.layers.9.mlp.gate_up_proj.weight": "model.safetensors",
219
+ "model.layers.layers.9.post_mixer_norm.weight": "model.safetensors",
220
+ "model.layers.layers.9.post_mlp_norm.weight": "model.safetensors",
221
+ "model.layers.layers.9.pre_mixer_norm.weight": "model.safetensors",
222
+ "model.layers.layers.9.pre_mlp_norm.weight": "model.safetensors",
223
+ "model.norm.weight": "model.safetensors"
224
+ }
225
+ }
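
Every tensor in the index maps to the single `model.safetensors` shard, and the tensor names make the layer layout explicit: even-numbered layers carry Mamba parameters (`A_log`, `conv1d`, `bcdt_proj`, ...) while odd-numbered layers carry attention parameters (`qkv_proj`, `q_weight`, `k_weight`, `o_proj`). A small sketch, assuming `huggingface_hub` is installed, that recovers this layout from the weight map:

```python
# Sketch: classify each layer as Mamba or attention from the tensor names above.
import json

from huggingface_hub import hf_hub_download

index_path = hf_hub_download("mlx-community/plamo-2-1b-bf16", "model.safetensors.index.json")
with open(index_path) as f:
    weight_map = json.load(f)["weight_map"]

mixers: dict[int, set[str]] = {}
for name in weight_map:
    parts = name.split(".")  # e.g. model.layers.layers.0.mixer.A_log
    if len(parts) >= 6 and parts[4] == "mixer":
        mixers.setdefault(int(parts[3]), set()).add(parts[5])

for layer in sorted(mixers):
    print(layer, "mamba" if "A_log" in mixers[layer] else "attention")
```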
modeling_plamo.py ADDED
@@ -0,0 +1,1699 @@
1
+ import enum
2
+ import math
3
+ import warnings
4
+ from typing import Any, Dict, List, Literal, NamedTuple, Optional, Tuple, Union
5
+
6
+ try:
7
+ # It is difficult to install mamba_ssm on a login node because
8
+ # it requires a GPU for installation
9
+ import mamba_ssm
10
+ except ModuleNotFoundError:
11
+ warnings.warn("mamba_ssm could not be imported", stacklevel=2)
12
+ try:
13
+ # It is difficult to install causal_conv1d on a login node because
14
+ # it requires a GPU for installation
15
+ import causal_conv1d.causal_conv1d_interface as causal_conv1d
16
+ except ModuleNotFoundError:
17
+ warnings.warn("causal_conv1d could not be imported", stacklevel=2)
18
+ import torch
19
+ from torch import nn
20
+ from torch.nn import functional as F
21
+ from transformers import PretrainedConfig, PreTrainedModel
22
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
23
+
24
+
25
+ def _is_first_token(mask: torch.Tensor) -> torch.Tensor:
26
+ assert mask.dtype == torch.bool
27
+ B, Nh, q_len, kv_len = mask.shape
28
+ mask = mask[:, :, :, -q_len:]
29
+ cont = q_len != kv_len
30
+ v = False if cont else True
31
+ out = torch.logical_not(torch.diagonal(mask, offset=-1, dim1=-2, dim2=-1).bool())
32
+ out = torch.cat(
33
+ [
34
+ torch.full(size=(B, Nh, 1), dtype=torch.bool, device=out.device, fill_value=v),
35
+ out,
36
+ ],
37
+ dim=-1,
38
+ )
39
+ return out
40
+
41
+
42
+ def _swiglu(h: torch.Tensor) -> torch.Tensor:
43
+ h0, h1 = h.chunk(2, dim=-1)
44
+ return torch.nn.functional.silu(h0) * h1
45
+
46
+
47
+ class RotaryEmbedding(torch.nn.Module):
48
+ def __init__(
49
+ self, dim: int, max_position_embeddings: int = 2048, base: int = 10000, device: Optional[torch.device] = None
50
+ ) -> None:
51
+ super().__init__()
52
+
53
+ self.dim = dim
54
+ self.max_position_embeddings = max_position_embeddings
55
+ self.base = base
56
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
57
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
58
+
59
+ # Build here to make `torch.jit.trace` work.
60
+ self._set_cos_sin_cache(
61
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
62
+ )
63
+
64
+ def _set_cos_sin_cache(self, seq_len: int, device: Any, dtype: Any) -> None:
65
+ self.max_seq_len_cached = seq_len
66
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) # type: ignore
67
+
68
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
69
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
70
+ emb = torch.cat((freqs, freqs), dim=-1)
71
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
72
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
73
+
74
+ def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
75
+ # x: [bs, num_attention_heads, seq_len, head_size]
76
+ if seq_len > self.max_seq_len_cached:
77
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
78
+
79
+ return (
80
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), # type: ignore
81
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), # type: ignore
82
+ )
83
+
84
+
85
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
86
+ """Rotates half the hidden dims of the input."""
87
+ x1 = x[..., : x.shape[-1] // 2]
88
+ x2 = x[..., x.shape[-1] // 2 :]
89
+ return torch.cat((-x2, x1), dim=-1)
90
+
91
+
92
+ def _rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
93
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
94
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
95
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
96
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
97
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
98
+ x_embed = (x * cos) + (_rotate_half(x) * sin)
99
+ return x_embed
100
+
101
+
102
+ class LinearType(str, enum.Enum):
103
+ Normal = "normal"
104
+ Fp8 = "fp8"
105
+ Fp8Retain = "fp8-retain"
106
+
107
+
108
+ class PlamoConfig(PretrainedConfig): # type: ignore
109
+ model_type: str = "plamo"
110
+
111
+ def __init__(
112
+ self,
113
+ hidden_size: int = 4096,
114
+ num_hidden_layers: int = 32,
115
+ rms_norm_eps: float = 1e-6,
116
+ tie_word_embeddings: bool = True,
117
+ # Attention
118
+ num_attention_heads: int = 32,
119
+ num_key_value_heads: int = 4,
120
+ hidden_size_per_head: int = 128,
121
+ max_position_embeddings: int = 2048,
122
+ attention_window_size: int = 2048,
123
+ full_attention_idx: list[int] | None = None,
124
+ # Mamba
125
+ mamba_d_state: int = 64,
126
+ mamba_d_conv: int = 4,
127
+ mamba_num_heads: int = 64,
128
+ mamba_step: int = 2,
129
+ mamba_chunk_size: int = 256,
130
+ mamba_enabled: bool = True,
131
+ # MLP
132
+ intermediate_size: int = 13312,
133
+ # Tokenizer
134
+ vocab_size: int = 32000,
135
+ tokenizer_class: str = "PlamoTokenizer",
136
+ pad_token_id: Optional[int] = None,
137
+ bos_token_id: int = 1,
138
+ eos_token_id: int = 2,
139
+ # Multimodal
140
+ image_token_id: Optional[int] = None,
141
+ image_feature_size: Optional[int] = None,
142
+ image_proj_type: Literal["linear", "mlp"] = "linear",
143
+ # FP8
144
+ linear_type: LinearType = LinearType.Normal,
145
+ fp8_accum_dtype: Optional[str] = None,
146
+ # Evaluation
147
+ eval_attention_n_bit: Optional[int] = None,
148
+ eval_mlp_n_bit: Optional[int] = None,
149
+ use_cache: bool = True,
150
+ **kwargs: Any,
151
+ ) -> None:
152
+ # max_position_embeddings is often used to determine the max length during inference,
153
+ # but samba should have extrapolation abilities
154
+ self.max_position_embeddings = max(10 * 1024 * 1024, max_position_embeddings)
155
+ self.hidden_size = hidden_size
156
+ self.rms_norm_eps = rms_norm_eps
157
+
158
+ self.num_hidden_layers = num_hidden_layers
159
+ self.num_attention_heads = num_attention_heads
160
+ self.hidden_size_per_head = hidden_size_per_head
161
+ self.num_key_value_heads = num_key_value_heads
162
+ self.attention_window_size = attention_window_size
163
+ self.full_attention_idx = full_attention_idx if full_attention_idx is not None else []
164
+
165
+ self.mamba_d_state = mamba_d_state
166
+ self.mamba_d_conv = mamba_d_conv
167
+ self.mamba_num_heads = mamba_num_heads
168
+ self.mamba_step = mamba_step
169
+ self.mamba_chunk_size = mamba_chunk_size
170
+ self.mamba_enabled = mamba_enabled
171
+
172
+ self.intermediate_size = intermediate_size
173
+
174
+ self.vocab_size = vocab_size
175
+
176
+ self.image_token_id = image_token_id
177
+ self.image_feature_size = image_feature_size
178
+ self.image_proj_type = image_proj_type
179
+
180
+ self.linear_type = linear_type
181
+ self.fp8_accum_dtype = fp8_accum_dtype
182
+
183
+ self.eval_attention_n_bit = eval_attention_n_bit
184
+ self.eval_mlp_n_bit = eval_mlp_n_bit
185
+ self.use_cache = use_cache
186
+
187
+ # fields for vLLM
188
+ self.sliding_window = attention_window_size
189
+
190
+ super().__init__(
191
+ tokenizer_class=tokenizer_class,
192
+ pad_token_id=pad_token_id,
193
+ bos_token_id=bos_token_id,
194
+ eos_token_id=eos_token_id,
195
+ tie_word_embeddings=tie_word_embeddings,
196
+ **kwargs,
197
+ )
198
+
199
+
200
+ class PlamoAttentionCache(torch.nn.Module):
201
+ def __init__(self, key: torch.Tensor, value: torch.Tensor) -> None:
202
+ super().__init__()
203
+ B, nh, L, c = key.shape
204
+ assert len(value.shape) == 4
205
+ assert value.shape[0] == B
206
+ assert value.shape[2] == L
207
+ self.register_parameter("key", torch.nn.Parameter(key, requires_grad=False))
208
+ self.register_parameter("value", torch.nn.Parameter(value, requires_grad=False))
209
+
210
+
211
+ class PlamoMambaCache(torch.nn.Module):
212
+ def __init__(self, conv_state: torch.Tensor, ssm_state: torch.Tensor) -> None:
213
+ super().__init__()
214
+ # conv_state: [B, C, d_conv]
215
+ # ssm_state: [B, nhead, nchannel_per_head, d_state]
216
+ assert len(conv_state.shape) == 3
217
+ assert len(ssm_state.shape) == 4
218
+ assert conv_state.shape[0] == ssm_state.shape[0]
219
+ self.register_parameter("conv_state", torch.nn.Parameter(conv_state, requires_grad=False))
220
+ self.register_parameter("ssm_state", torch.nn.Parameter(ssm_state, requires_grad=False))
221
+
222
+
223
+ PlamoLayerCache = PlamoAttentionCache | PlamoMambaCache
224
+
225
+
226
+ class PlamoCache(torch.nn.Module):
227
+ """
228
+ Stores states of the model for fast decoding.
229
+ `transformers` uses `transformers.Cache` for this purpose, but its interface and variable names are
230
+ deeply tied to the Transformers attention architecture (e.g., `key_states`), so it is difficult to use
231
+ with other architectures (e.g., Mamba).
232
+ This class provides a similar interface to `transformers.Cache`, but is designed to also handle
233
+ the state of Mamba properly.
234
+ """
235
+
236
+ def __init__(self, config: PlamoConfig) -> None:
237
+ super().__init__()
238
+ self.config = config
239
+ self.cache = torch.nn.ModuleList([None for _ in range(config.num_hidden_layers)]) # type: ignore
240
+
241
+ def append_kv(self, key: torch.Tensor, value: torch.Tensor, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
242
+ c = self.cache[layer_idx]
243
+ if c is None:
244
+ return key, value
245
+ assert isinstance(c, PlamoAttentionCache)
246
+
247
+ def _validate(cache: torch.Tensor, new_tensor: torch.Tensor) -> None:
248
+ assert len(cache.shape) == 4
249
+ assert len(new_tensor.shape) == 4
250
+ assert cache.shape[0] == new_tensor.shape[0]
251
+ assert cache.shape[1] == new_tensor.shape[1]
252
+ assert cache.shape[3] == new_tensor.shape[3]
253
+
254
+ _validate(c.key, key)
255
+ _validate(c.value, value)
256
+ assert key.shape[2] == value.shape[2]
257
+ return torch.cat([c.key, key], dim=2), torch.cat([c.value, value], dim=2)
258
+
259
+ def update_attention(
260
+ self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int
261
+ ) -> PlamoAttentionCache:
262
+ full_attn = layer_idx in self.config.full_attention_idx
263
+ window_size = self.config.attention_window_size
264
+
265
+ if self.cache[layer_idx] is None:
266
+ if full_attn:
267
+ self.cache[layer_idx] = PlamoAttentionCache(key_states, value_states)
268
+ else:
269
+ self.cache[layer_idx] = PlamoAttentionCache(
270
+ key_states[:, :, -window_size:, :], value_states[:, :, -window_size:, :]
271
+ )
272
+ else:
273
+ c = self.cache[layer_idx]
274
+ assert isinstance(c, PlamoAttentionCache)
275
+ k, v = self.append_kv(key_states, value_states, layer_idx)
276
+ if full_attn:
277
+ c.key.data = k
278
+ c.value.data = v
279
+ else:
280
+ c.key.data = k[:, :, -window_size:, :]
281
+ c.value.data = v[:, :, -window_size:, :]
282
+ return self.cache[layer_idx] # type: ignore
283
+
284
+ def update_mamba(self, conv_state: torch.Tensor, ssm_state: torch.Tensor, layer_idx: int) -> PlamoMambaCache:
285
+ if self.cache[layer_idx] is None:
286
+ self.cache[layer_idx] = PlamoMambaCache(conv_state, ssm_state)
287
+ else:
288
+ c = self.cache[layer_idx]
289
+ assert isinstance(c, PlamoMambaCache)
290
+ assert c.conv_state.shape == conv_state.shape
291
+ assert c.ssm_state.shape == ssm_state.shape
292
+ c.conv_state.data = conv_state
293
+ c.ssm_state.data = ssm_state
294
+ return self.cache[layer_idx] # type: ignore
295
+
296
+ def __getitem__(self, layer_idx: int) -> PlamoLayerCache | None:
297
+ assert layer_idx < len(self.cache)
298
+ layer_cache = self.cache[layer_idx]
299
+ return layer_cache # type: ignore
300
+
301
+ def __len__(self) -> int:
302
+ return len(self.cache)
303
+
304
+ def get_seq_length(self, layer_idx: Optional[int] = None) -> int:
305
+ if layer_idx is not None:
306
+ c = self.cache[layer_idx]
307
+ assert isinstance(c, PlamoAttentionCache)
308
+ return c.key.shape[2] # type: ignore
309
+
310
+ sequence_length: int | None = None
311
+ for layer_cache in self.cache:
312
+ if isinstance(layer_cache, PlamoAttentionCache):
313
+ sequence_length = (
314
+ max(layer_cache.key.shape[2], sequence_length)
315
+ if sequence_length is not None
316
+ else layer_cache.key.shape[2]
317
+ )
318
+ assert sequence_length is not None
319
+ return sequence_length
320
+
321
+ def get_max_length(self) -> int | None:
322
+ return None
323
+
324
+ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
325
+ """Given the sequence length of the new inputs, returns the usable length of the cache."""
326
+ # Cache without size limit -> all cache is usable
327
+ # Cache with size limit -> if the cached length plus the length of the new inputs is larger than the maximum cache
328
+ # length, we will need to evict part of the cache (and thus not all cache is usable)
329
+ max_length = self.get_max_length()
330
+ previous_seq_length = self.get_seq_length(layer_idx)
331
+ if max_length is not None and previous_seq_length + new_seq_length > max_length:
332
+ return max_length - new_seq_length
333
+ return previous_seq_length
334
+
335
+ def reorder_cache(self, beam_idx: torch.Tensor) -> None:
336
+ def _mamba(cache: PlamoMambaCache) -> PlamoMambaCache:
337
+ return PlamoMambaCache(
338
+ conv_state=cache.conv_state.index_select(0, beam_idx),
339
+ ssm_state=cache.ssm_state.index_select(0, beam_idx),
340
+ )
341
+
342
+ def _attention(cache: PlamoAttentionCache) -> PlamoAttentionCache:
343
+ return PlamoAttentionCache(
344
+ key=cache.key.index_select(0, beam_idx),
345
+ value=cache.value.index_select(0, beam_idx),
346
+ )
347
+
348
+ for i in range(len(self.cache)):
349
+ if self.cache[i] is None:
350
+ continue
351
+ layer_cache = self.cache[i]
352
+ if isinstance(layer_cache, PlamoMambaCache):
353
+ self.cache[i] = _mamba(layer_cache)
354
+ else:
355
+ assert isinstance(layer_cache, PlamoAttentionCache)
356
+ self.cache[i] = _attention(layer_cache)
357
+
358
+ @property
359
+ def seen_tokens(self) -> int | None:
360
+ return None
361
+
362
+
363
+ class DecoderInput(NamedTuple):
364
+ hidden_states: torch.Tensor
365
+ attention_mask: Optional[torch.Tensor] = None
366
+ past_states: Optional[PlamoCache] = None
367
+ output_hidden_states: Optional[bool] = False
368
+ output_attentions: Optional[bool] = False
369
+ gradient_checkpointing: bool = False
370
+ input_ids: Optional[torch.Tensor] = None
371
+
372
+
373
+ class DecoderOutput(NamedTuple):
374
+ hidden_states: torch.Tensor
375
+ all_hidden_states: Optional[Tuple[torch.Tensor, ...]]
376
+ all_self_attns: Optional[Tuple[torch.Tensor, ...]]
377
+
378
+
379
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
380
+ def _make_causal_mask(
381
+ input_ids_shape: Tuple[int, int], dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
382
+ ) -> torch.Tensor:
383
+ """
384
+ Make causal mask used for bi-directional self-attention.
385
+ """
386
+ bsz, tgt_len = input_ids_shape
387
+ mask = torch.full((tgt_len, tgt_len), float("-inf"), device=device)
388
+ mask_cond = torch.arange(mask.size(-1), device=device)
389
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
390
+ mask = mask.to(dtype)
391
+
392
+ if past_key_values_length > 0:
393
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
394
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
395
+
396
+
397
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
398
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None) -> torch.Tensor:
399
+ """
400
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
401
+ """
402
+ bsz, src_len = mask.size()
403
+ tgt_len = tgt_len if tgt_len is not None else src_len
404
+
405
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
406
+
407
+ inverted_mask = 1.0 - expanded_mask
408
+
409
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), float("-inf")) # type: ignore
410
+
411
+
412
+ def _rms_norm(
413
+ hidden_states: torch.Tensor, weight: Optional[torch.Tensor], eps: float, offset: float = 1.0
414
+ ) -> torch.Tensor:
415
+ input_dtype = hidden_states.dtype
416
+ hidden_states = hidden_states.to(torch.float32)
417
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
418
+ hidden_states = hidden_states * torch.rsqrt(variance + eps)
419
+ hidden_states = hidden_states.to(input_dtype)
420
+ if weight is not None:
421
+ hidden_states = (offset + weight) * hidden_states
422
+ return hidden_states
423
+
424
+
425
+ class RMSNorm(nn.Module):
426
+ def __init__(
427
+ self,
428
+ hidden_size: int,
429
+ eps: float = 1e-6,
430
+ offset: float = 1.0,
431
+ device: Optional[Union[torch.device, str]] = None,
432
+ ) -> None:
433
+ super().__init__()
434
+ self.weight = nn.Parameter(torch.zeros(hidden_size, device=device))
435
+ self.variance_epsilon = eps
436
+ self.offset = offset
437
+
438
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
439
+ return _rms_norm(hidden_states, self.weight, self.variance_epsilon, offset=self.offset)
440
+
441
+
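# Note on the normalization above (added comment, not in the uploaded file): RMSNorm
# initializes `weight` to zeros and `_rms_norm` applies the scale as (offset + weight) = 1 + weight,
# so a freshly constructed norm has unit gain and the checkpoint's *_norm.weight tensors
# store the learned deviation from 1 rather than the full gain.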
442
+ def get_initial_dt_bias(num_heads: int) -> torch.Tensor:
443
+ dt_min = 0.001
444
+ dt_max = 0.1
445
+ dt = torch.exp(torch.rand(num_heads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
446
+ dt = torch.clamp(dt, 1e-4)
447
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
448
+ return inv_dt
449
+
450
+
451
+ def get_initial_A(num_heads: int) -> torch.Tensor:
452
+ A = torch.arange(1, num_heads + 1, dtype=torch.float32)
453
+ return torch.log(A)
454
+
455
+
456
+ def _bf16_supported_in_triton() -> bool:
457
+ # newer torch (2.2.0 and later?) supports bfloat16 even when using Voltas
458
+ # but triton cannot compile bf16 kernels for Volta
459
+ major, _ = torch.cuda.get_device_capability()
460
+ return major >= 8
461
+
462
+
463
+ def _get_trition_dtype(dtype: torch.dtype) -> torch.dtype:
464
+ if dtype != torch.bfloat16:
465
+ return dtype
466
+ if _bf16_supported_in_triton():
467
+ return dtype
468
+ return torch.float32
469
+
470
+
471
+ def ssd_update_state(
472
+ ssm_state: torch.Tensor,
473
+ x: torch.Tensor,
474
+ dt: torch.Tensor,
475
+ A: torch.Tensor,
476
+ B: torch.Tensor,
477
+ C: torch.Tensor,
478
+ D: torch.Tensor,
479
+ z: torch.Tensor,
480
+ dt_bias: torch.Tensor,
481
+ dt_softplus: bool,
482
+ ) -> torch.Tensor:
483
+ assert ssm_state.dtype == torch.float32
484
+ if dt.is_cuda:
485
+ dtype = _get_trition_dtype(x.dtype)
486
+ else:
487
+ dtype = x.dtype
488
+ if dt.is_cuda:
489
+ f = mamba_ssm.ops.triton.selective_state_update.selective_state_update
490
+ else:
491
+ f = mamba_ssm.ops.triton.selective_state_update.selective_state_update_ref
492
+
493
+ hidden_size_per_head = x.shape[-1]
494
+ d_state = B.shape[-1]
495
+ A = A[:, None, None].expand(-1, hidden_size_per_head, d_state).float()
496
+ dt = dt[..., None].expand(-1, -1, hidden_size_per_head)
497
+ dt_bias = dt_bias[:, None].expand(-1, hidden_size_per_head)
498
+ D = D[:, None].expand(-1, hidden_size_per_head)
499
+ assert ssm_state.dtype == torch.float32
500
+ out = f(
501
+ ssm_state,
502
+ x.to(dtype),
503
+ dt.to(dtype),
504
+ A.float(),
505
+ B.to(dtype),
506
+ C.to(dtype),
507
+ D.float(),
508
+ z.to(dtype),
509
+ dt_bias.float(),
510
+ dt_softplus=dt_softplus,
511
+ )
512
+ return out[:, None] # type: ignore
513
+
514
+
515
+ def _ssd_chunk_scan_combined_naive(
516
+ x: torch.Tensor,
517
+ dt: torch.Tensor,
518
+ A: torch.Tensor,
519
+ B: torch.Tensor,
520
+ C: torch.Tensor,
521
+ D: torch.Tensor,
522
+ z: torch.Tensor,
523
+ dt_bias: torch.Tensor,
524
+ dt_softplus: bool,
525
+ seq_idx: torch.Tensor | None,
526
+ ssm_state: torch.Tensor,
527
+ ) -> tuple[torch.Tensor, torch.Tensor]:
528
+ assert ssm_state.dtype == torch.float32
529
+ length = x.shape[1]
530
+ ys = []
531
+ for i in range(length):
532
+ if i != 0 and seq_idx is not None:
533
+ ssm_state = torch.where(
534
+ (seq_idx[:, i - 1] != seq_idx[:, i])[:, None, None, None],
535
+ torch.zeros_like(ssm_state),
536
+ ssm_state,
537
+ )
538
+ y = ssd_update_state(
539
+ ssm_state,
540
+ x[:, i],
541
+ dt[:, i],
542
+ A,
543
+ B[:, i],
544
+ C[:, i],
545
+ D,
546
+ z=z[:, i],
547
+ dt_bias=dt_bias,
548
+ dt_softplus=dt_softplus,
549
+ )
550
+ ys.append(y)
551
+ return torch.cat(ys, dim=1), ssm_state
552
+
553
+
554
+ def _ssd_chunk_scan_combined_cpu(
555
+ x: torch.Tensor,
556
+ dt: torch.Tensor,
557
+ A: torch.Tensor,
558
+ B: torch.Tensor,
559
+ C: torch.Tensor,
560
+ chunk_size: int,
561
+ D: torch.Tensor,
562
+ z: torch.Tensor,
563
+ dt_bias: torch.Tensor,
564
+ dt_softplus: bool,
565
+ ) -> tuple[torch.Tensor, torch.Tensor]:
566
+ # (bsize, nhead, nchunk, chunk_size)
567
+ dt = dt.float() # We want high precision for this before cumsum
568
+ dt = dt.permute(0, 2, 1).unflatten(2, (-1, chunk_size)) # type: ignore
569
+ if dt_bias is not None:
570
+ dt = dt + dt_bias[None, :, None, None]
571
+ if dt_softplus:
572
+ dt = F.softplus(dt)
573
+ dA = dt * A[None, :, None, None]
574
+ dA_cumsum = torch.cumsum(dA, dim=-1)
575
+
576
+ _, _, nheads, _ = x.shape
577
+ dstate = B.shape[-1]
578
+ _ = dt.shape[2]
579
+
580
+ with torch.profiler.record_function("ssd_chunk_scan_combined_cpu_chunk_state"):
581
+ # Following is equivalent to `mamba_ssm.ops.triton.ssd_combined.chunk_state_ref(B, x, dt, dA_cumsum)`
582
+ # But `einsum` in the above function is too slow on CPU.
583
+ x_ = torch.unflatten(x, 1, (-1, chunk_size))
584
+ assert B.shape[2] == nheads # B should be already expanded
585
+ B_ = torch.unflatten(B, 1, (-1, chunk_size)).to(x.dtype) # (bsize, nchunk, chunk_size, nheads, dstate)
586
+ decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum)).to(x.dtype)
587
+ dt_ = dt.to(x.dtype)
588
+
589
+ # einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B_, decay_states, dt_, x_)
590
+ B_ = B_.permute(0, 1, 3, 4, 2) # bchnl
591
+ tmp = dt_ * decay_states # bhcl
592
+ tmp = tmp.permute(0, 2, 1, 3)[:, :, :, None] # bch1l
593
+ tmp = B_ * tmp # bchnl
594
+ x_ = x_.permute(0, 1, 3, 2, 4) # bchlp
595
+ tmp = tmp @ x_ # bchnp
596
+ states = tmp.permute(0, 1, 2, 4, 3) # bchpn
597
+
598
+ states_dtype = states.dtype
599
+ if states.dtype not in [torch.float32, torch.float64]:
600
+ states = states.to(torch.float32)
601
+ with torch.profiler.record_function("ssd_chunk_scan_combined_cpu_state_passing"):
602
+ out, last_state = mamba_ssm.ops.triton.ssd_combined.state_passing_ref(
603
+ states.flatten(start_dim=-2, end_dim=-1),
604
+ dA_cumsum[:, :, :, -1],
605
+ )
606
+ states = torch.unflatten(out, -1, (-1, dstate))
607
+ last_state = torch.unflatten(last_state, -1, (-1, dstate))
608
+ states = states.to(states_dtype)
609
+ with torch.profiler.record_function("ssd_chunk_scan_combined_cpu_chunk_scan"):
610
+ out = mamba_ssm.ops.triton.ssd_combined.chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z)
611
+
612
+ return out, last_state
613
+
614
+
615
+ @torch.profiler.record_function("ssd_chunk_scan_combined")
616
+ def ssd_chunk_scan_combined(
617
+ x: torch.Tensor,
618
+ dt: torch.Tensor,
619
+ A: torch.Tensor,
620
+ B: torch.Tensor,
621
+ C: torch.Tensor,
622
+ chunk_size: int,
623
+ D: torch.Tensor,
624
+ z: torch.Tensor,
625
+ dt_bias: torch.Tensor,
626
+ dt_softplus: bool,
627
+ return_final_states: bool,
628
+ seq_idx: torch.Tensor | None,
629
+ ssm_state: torch.Tensor | None,
630
+ ) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
631
+ if seq_idx is not None:
632
+ assert seq_idx.dtype == torch.int32
633
+ assert ssm_state is None
634
+ assert not return_final_states
635
+ if ssm_state is not None:
636
+ assert ssm_state.dtype == torch.float32
637
+ assert seq_idx is None
638
+
639
+ length = x.shape[1]
640
+
641
+ """
642
+ The state will be updated by the following recurrence:
643
+ ```
644
+ dt = softplus(dt)
645
+ dA = exp(dt * A)
646
+ state_next = state * dA + dB * x
647
+ ```
648
+
649
+ To avoid updating state, we set dt to -inf and x to 0
650
+ because `softplus(-inf) = 0` and `exp(0) = 1`
651
+ """
652
+ pad = (chunk_size - length % chunk_size) % chunk_size
653
+ x = torch.nn.functional.pad(x, pad=[0, 0, 0, 0, pad, 0], value=0.0)
654
+ dt = torch.nn.functional.pad(dt, pad=[0, 0, pad, 0], value=float("-inf"))
655
+ B = torch.nn.functional.pad(B, pad=[0, 0, 0, 0, pad, 0], value=0.0)
656
+ C = torch.nn.functional.pad(C, pad=[0, 0, 0, 0, pad, 0], value=0.0)
657
+ z = torch.nn.functional.pad(z, pad=[0, 0, 0, 0, pad, 0], value=0.0)
658
+ if seq_idx is not None:
659
+ seq_idx = torch.nn.functional.pad(seq_idx, pad=[pad, 0], value=0)
660
+
661
+ length = x.shape[1]
662
+ assert length % chunk_size == 0, (length, chunk_size)
663
+
664
+ if dt.is_cuda:
665
+ dtype = _get_trition_dtype(x.dtype)
666
+ out = mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined( # type: ignore
667
+ x.to(dtype),
668
+ dt.to(dtype),
669
+ A.float(),
670
+ B.to(dtype),
671
+ C.to(dtype),
672
+ chunk_size,
673
+ D=D.float(),
674
+ z=z.to(dtype),
675
+ initial_states=ssm_state,
676
+ dt_bias=dt_bias.float(),
677
+ dt_softplus=dt_softplus,
678
+ seq_idx=seq_idx,
679
+ return_final_states=return_final_states,
680
+ )
681
+ if return_final_states:
682
+ return out[0][:, pad:], out[1]
683
+ else:
684
+ assert isinstance(out, torch.Tensor)
685
+ return out[:, pad:]
686
+ else:
687
+ if ssm_state is None and seq_idx is None:
688
+ tmp = _ssd_chunk_scan_combined_cpu(
689
+ x,
690
+ dt,
691
+ A,
692
+ B,
693
+ C,
694
+ chunk_size,
695
+ D=D,
696
+ z=z,
697
+ dt_bias=dt_bias.float(),
698
+ dt_softplus=dt_softplus,
699
+ )
700
+ else:
701
+ if ssm_state is None:
702
+ bsize, _, num_heads, channel = x.shape
703
+ state = B.shape[-1]
704
+ ssm_state = torch.zeros(bsize, num_heads, channel, state, dtype=torch.float32, device=x.device)
705
+ tmp = _ssd_chunk_scan_combined_naive(
706
+ x, dt, A, B, C, D, z=z, dt_bias=dt_bias, dt_softplus=dt_softplus, seq_idx=seq_idx, ssm_state=ssm_state
707
+ )
708
+ tmp = (tmp[0][:, pad:], tmp[1])
709
+ if return_final_states:
710
+ return tmp
711
+ else:
712
+ return tmp[0]
713
+
714
+
715
+ def _causal_conv1d_update(
716
+ conv_state: torch.Tensor, weight: torch.Tensor, xBC: torch.Tensor
717
+ ) -> tuple[torch.Tensor, torch.Tensor]:
718
+ dtype = conv_state.dtype
719
+ xBC = xBC.to(dtype)
720
+ weight = weight.to(dtype)
721
+ if conv_state.is_cuda:
722
+ x = causal_conv1d.causal_conv1d_update(
723
+ x=xBC,
724
+ conv_state=conv_state,
725
+ weight=weight[:, 0, :],
726
+ activation="silu",
727
+ )
728
+ return x, conv_state
729
+ else:
730
+ x = causal_conv1d.causal_conv1d_update_ref(
731
+ x=xBC,
732
+ conv_state=conv_state,
733
+ weight=weight[:, 0, :],
734
+ activation="silu",
735
+ )
736
+ return x, conv_state
737
+
738
+
739
+ def _causal_conv1d_naive(
740
+ conv_state: torch.Tensor, weight: torch.Tensor, x: torch.Tensor, seq_idx: torch.Tensor | None
741
+ ) -> tuple[torch.Tensor, torch.Tensor]:
742
+ length = x.shape[-1]
743
+ out = torch.zeros_like(x)
744
+ for i in range(length):
745
+ if i != 0 and seq_idx is not None:
746
+ conv_state = torch.where(
747
+ (seq_idx[:, i - 1] != seq_idx[:, i])[:, None, None],
748
+ torch.zeros_like(conv_state),
749
+ conv_state,
750
+ )
751
+ out[:, :, i : i + 1], conv_state = _causal_conv1d_update(conv_state, weight, x[:, :, i : i + 1])
752
+ return out, conv_state
753
+
754
+
755
+ @torch.profiler.record_function("causal_conv1d")
756
+ def _causal_conv1d(
757
+ conv_state: torch.Tensor | None, weight: torch.Tensor, x: torch.Tensor, seq_idx: torch.Tensor | None
758
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
759
+ dtype = x.dtype
760
+ if conv_state is not None:
761
+ dtype = conv_state.dtype
762
+ assert seq_idx is None
763
+ if seq_idx is not None:
764
+ assert seq_idx.dtype == torch.int32
765
+ assert conv_state is None
766
+ weight = weight.to(dtype)
767
+ x = x.to(dtype)
768
+
769
+ return_final_states = conv_state is not None
770
+ if weight.is_cuda:
771
+ if x.stride(1) != 1:
772
+ # to channel-last format
773
+ x = x.transpose(-1, -2).contiguous().transpose(-1, -2)
774
+ if conv_state is not None:
775
+ if conv_state.stride(1) != 1:
776
+ # to channel-last format
777
+ conv_state = conv_state.transpose(-1, -2).contiguous().transpose(-1, -2)
778
+ tmp = causal_conv1d.causal_conv1d_fn(
779
+ x=x,
780
+ weight=weight[:, 0, :],
781
+ initial_states=conv_state,
782
+ return_final_states=conv_state is not None,
783
+ activation="silu",
784
+ seq_idx=seq_idx,
785
+ )
786
+ if conv_state is not None:
787
+ x, conv_state = tmp
788
+ else:
789
+ x = tmp
790
+ else:
791
+ if seq_idx is None:
792
+ x, conv_state = causal_conv1d.causal_conv1d_ref(
793
+ x=x,
794
+ initial_states=conv_state,
795
+ return_final_states=True,
796
+ weight=weight[:, 0, :],
797
+ activation="silu",
798
+ )
799
+ else:
800
+ if conv_state is None:
801
+ bsize = x.shape[0]
802
+ dim = weight.shape[0]
803
+ d_conv = weight.shape[-1]
804
+ conv_state = torch.zeros(bsize, dim, d_conv - 1, dtype=x.dtype, device=x.device)
805
+ x, conv_state = _causal_conv1d_naive(conv_state, weight, x, seq_idx)
806
+ if return_final_states:
807
+ return x, conv_state
808
+ else:
809
+ return x, None
810
+
811
+
812
+ class Mamba(torch.nn.Module):
813
+ def __init__(self, config: PlamoConfig, layer_idx: int) -> None:
814
+ super().__init__()
815
+ self.config = config
816
+ self.layer_idx = layer_idx
817
+ self.hidden_size = config.hidden_size
818
+ self.d_state = config.mamba_d_state
819
+ self.d_conv = config.mamba_d_conv
820
+ self.chunk_size = config.mamba_chunk_size
821
+ self.num_heads = config.mamba_num_heads
822
+ # TODO add mamba_hidden_size_per_head config (?)
823
+ self.hidden_size_per_head = config.hidden_size_per_head
824
+
825
+ self.intermediate_size = self.num_heads * self.hidden_size_per_head
826
+
827
+ self.in_proj = torch.nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
828
+ self.conv1d = torch.nn.Conv1d(
829
+ in_channels=self.intermediate_size,
830
+ out_channels=self.intermediate_size,
831
+ bias=False, # TODO the original implementation uses bias
832
+ kernel_size=self.d_conv,
833
+ groups=self.intermediate_size,
834
+ padding=0,
835
+ )
836
+ self.dt_dim = max(64, self.hidden_size // 16)
837
+ # Notes:
838
+ # Mamba2 removes this linear projection for simplicity (Figure 6 in the paper),
839
+ # but it may degrade the ability of content-length extrapolation.
840
+ self.bcdt_proj = torch.nn.Linear(
841
+ self.intermediate_size,
842
+ self.dt_dim + 2 * self.d_state,
843
+ bias=False,
844
+ )
845
+ self.dt_proj = torch.nn.Linear(self.dt_dim, self.num_heads, bias=False)
846
+
847
+ self.dt_bias = torch.nn.Parameter(get_initial_dt_bias(self.num_heads))
848
+ self.A_log = torch.nn.Parameter(get_initial_A(self.num_heads))
849
+ self.D = torch.nn.Parameter(torch.ones(self.num_heads))
850
+
851
+ # TODO norm weight before gating like Mamba2
852
+ self.dt_norm_weight = torch.nn.Parameter(torch.ones(self.dt_dim))
853
+ self.B_norm_weight = torch.nn.Parameter(torch.ones(self.d_state))
854
+ self.C_norm_weight = torch.nn.Parameter(torch.ones(self.d_state))
855
+
856
+ self.out_proj = torch.nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
857
+
858
+ def _no_weight_decay_param_names(self) -> set[str]:
859
+ return set(["D", "dt_bias", "A_log"])
860
+
861
+ def forward(
862
+ self,
863
+ hidden_states: torch.Tensor,
864
+ attention_mask: Optional[torch.Tensor] = None,
865
+ past_states: Optional[PlamoCache] = None,
866
+ ) -> Tuple[torch.Tensor, Optional[PlamoCache]]:
867
+ bsize, length, _ = hidden_states.shape
868
+ is_update = length == 1 and past_states is not None
869
+
870
+ bool_mask: torch.Tensor | None = None
871
+ seq_idx: torch.Tensor | None = None
872
+ if attention_mask is not None:
873
+ if len(attention_mask.shape) == 2:
874
+ attention_mask = attention_mask[None, None].expand(bsize, 1, -1, -1)
875
+ assert len(attention_mask.shape) == 4
876
+
877
+ if past_states is None:
878
+ # TODO: support seq_idx with cache
879
+ bool_mask_4d = attention_mask == 0
880
+ is_first_token = _is_first_token(bool_mask_4d)[:, 0, :]
881
+ seq_idx = torch.cumsum(is_first_token, dim=-1) - 1
882
+ seq_idx = seq_idx.to(torch.int32)
883
+
884
+ # The `generate` function creates an attention mask that contains past tokens,
885
+ # but mamba does not use them
886
+ attention_mask = attention_mask[:, 0, -length:, -length:]
887
+ bool_mask = torch.diagonal(attention_mask, dim1=-2, dim2=-1) == 0
888
+
889
+ conv_state: torch.Tensor | None
890
+ ssm_state: torch.Tensor | None
891
+ if past_states is None:
892
+ conv_state = None
893
+ ssm_state = None
894
+ elif past_states[self.layer_idx] is None:
895
+ conv_state = torch.zeros(
896
+ bsize, self.intermediate_size, self.d_conv - 1, dtype=hidden_states.dtype, device=hidden_states.device
897
+ )
898
+ ssm_state = torch.zeros(
899
+ bsize,
900
+ self.num_heads,
901
+ self.hidden_size_per_head,
902
+ self.d_state,
903
+ dtype=torch.float32,
904
+ device=hidden_states.device,
905
+ )
906
+ else:
907
+ c = past_states[self.layer_idx]
908
+ assert isinstance(c, PlamoMambaCache)
909
+ conv_state = c.conv_state
910
+ ssm_state = c.ssm_state
911
+
912
+ zx = self.in_proj(hidden_states)
913
+ zx = zx.reshape(bsize, length, self.num_heads, -1)
914
+ # z: (bsize, length, num_heads, hidden_size_per_head)
915
+ # x: (bsize, length, num_heads, hidden_size_per_head)
916
+ z, x = torch.split(zx, [self.hidden_size_per_head, self.hidden_size_per_head], dim=-1)
917
+
918
+ # conv
919
+ x = x.reshape(bsize, length, -1).transpose(1, 2) # (bsize, intermediate_size, length)
920
+ if bool_mask is not None:
921
+ x = torch.where(bool_mask[:, None, :], x, 0.0)
922
+ if is_update:
923
+ assert conv_state is not None
924
+ x, conv_state = _causal_conv1d_update(conv_state, self.conv1d.weight, x)
925
+ else:
926
+ x, conv_state = _causal_conv1d(conv_state, self.conv1d.weight, x, seq_idx=seq_idx)
927
+ x = x.to(dtype=hidden_states.dtype)
928
+ x = x.transpose(1, 2) # (bsize, length, intermediate_size)
929
+ x = x.reshape(bsize, length, -1)
930
+ # x: (bsize, length, num_heads, hidden_size_per_head)
931
+ # B: (bsize, length, 1, d_state)
932
+ # C: (bsize, length, 1, d_state)
933
+ # dt: (bsize, length, dt_dim)
934
+ BCdt = self.bcdt_proj(x)
935
+ x = x.reshape(bsize, length, self.num_heads, -1)
936
+ B, C, dt = torch.split(BCdt, [self.d_state, self.d_state, self.dt_dim], dim=-1)
937
+ B = B[:, :, None, :]
938
+ C = C[:, :, None, :]
939
+
940
+ A = -torch.exp(self.A_log.float()) # (num_heads,)
941
+ dt = _rms_norm(dt, None, self.config.rms_norm_eps) * self.dt_norm_weight[None, None, :]
942
+ B = _rms_norm(B, None, self.config.rms_norm_eps) * self.B_norm_weight[None, None, None, :]
943
+ C = _rms_norm(C, None, self.config.rms_norm_eps) * self.C_norm_weight[None, None, None, :]
944
+
945
+ # (bsize, length, num_heads, 1)
946
+ dt = self.dt_proj(dt)[..., None]
947
+
948
+ # TODO it may not be required
949
+ B = B.expand(-1, -1, self.num_heads, -1)
950
+ C = C.expand(-1, -1, self.num_heads, -1)
951
+
952
+ if bool_mask is not None:
953
+ """
954
+ the state will be updated as follows:
955
+ ```
956
+ dt = softplus(dt)
957
+ dA = exp(dt * A)
958
+ state_next = state * dA + dB * x
959
+ ```
960
+
961
+ To avoid updating state, we set dt to -inf and x to 0
962
+ because `softplus(-inf) = 0` and `exp(0) = 1`
963
+ """
964
+ dt = torch.where(bool_mask[:, :, None, None], dt, float("-inf"))
965
+ x = torch.where(bool_mask[:, :, None, None], x, 0.0)
966
+
967
+ # ssm
968
+ if is_update:
969
+ assert ssm_state is not None
970
+ out = ssd_update_state(
971
+ ssm_state,
972
+ x[:, 0],
973
+ dt[:, 0].reshape(bsize, -1),
974
+ A,
975
+ B[:, 0],
976
+ C[:, 0],
977
+ D=self.D,
978
+ z=z[:, 0],
979
+ dt_bias=self.dt_bias,
980
+ dt_softplus=True,
981
+ )
982
+ else:
983
+ tmp = ssd_chunk_scan_combined(
984
+ x,
985
+ dt.reshape(bsize, length, -1),
986
+ A,
987
+ B,
988
+ C,
989
+ self.chunk_size,
990
+ D=self.D,
991
+ z=z,
992
+ dt_bias=self.dt_bias,
993
+ dt_softplus=True,
994
+ return_final_states=past_states is not None,
995
+ seq_idx=seq_idx,
996
+ ssm_state=ssm_state,
997
+ )
998
+ if past_states is not None:
999
+ out, ssm_state = tmp
1000
+ else:
1001
+ assert isinstance(tmp, torch.Tensor)
1002
+ out = tmp
1003
+
1004
+ y = self.out_proj(out.reshape(bsize, length, -1))
1005
+
1006
+ if past_states is not None:
1007
+ assert ssm_state is not None
1008
+ assert conv_state is not None
1009
+ past_states.update_mamba(conv_state, ssm_state, self.layer_idx)
1010
+
1011
+ return y, past_states
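The docstring inside `forward` explains why masked positions use `dt = -inf` and `x = 0`: `softplus(-inf) = 0` and `exp(0) = 1`, so the SSM state passes through unchanged. A standalone numeric check of that identity (not part of the uploaded file; `dB = softplus(dt) * B` is assumed as the usual discretization):

```python
import torch
import torch.nn.functional as F

# Masked position: dt = -inf and x = 0, as set in Mamba.forward above.
dt = torch.tensor(float("-inf"))
x = torch.tensor(0.0)
A = torch.tensor(-0.5)      # any negative A works for this check
B = torch.tensor(1.2)
state = torch.tensor(3.0)   # arbitrary previous SSM state

dt = F.softplus(dt)         # softplus(-inf) = 0
dA = torch.exp(dt * A)      # exp(0) = 1
dB = dt * B                 # 0, so the input term vanishes
state_next = state * dA + dB * x
assert torch.equal(state_next, state)  # the state is carried over unchanged
```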
1012
+
1013
+
1014
+ def swa_mask(q_len: int, kv_len: int, device: torch.device, window_size: int) -> torch.Tensor:
1015
+ max_len = max(q_len, kv_len)
1016
+ mask = (
1017
+ torch.ones(max_len, max_len, dtype=torch.bool, device=device)
1018
+ .triu(diagonal=-window_size)
1019
+ .tril(diagonal=window_size)
1020
+ )
1021
+ return mask[-q_len:, -kv_len:]
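`swa_mask` builds a boolean band of width `window_size` on either side of the diagonal and then crops the bottom-right `q_len x kv_len` corner, so a decoding step (`q_len == 1`) still sees the current token plus the previous `window_size` keys. A small standalone illustration that mirrors the function body:

```python
import torch

q_len, kv_len, window_size = 1, 6, 3
max_len = max(q_len, kv_len)
mask = (
    torch.ones(max_len, max_len, dtype=torch.bool)
    .triu(diagonal=-window_size)
    .tril(diagonal=window_size)
)[-q_len:, -kv_len:]
print(mask.int())  # -> [[0, 0, 1, 1, 1, 1]]: the last 4 keys are visible
```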
1022
+
1023
+
1024
+ class Attention(torch.nn.Module):
1025
+ def __init__(self, config: PlamoConfig, layer_idx: int) -> None:
1026
+ super().__init__()
1027
+ self.config = config
1028
+ self.layer_idx = layer_idx
1029
+ self.hidden_size = config.hidden_size
1030
+ head_dim = config.hidden_size_per_head
1031
+ self.max_position_embeddings = config.max_position_embeddings
1032
+
1033
+ self.q_num_heads = config.num_attention_heads
1034
+ self.qk_dim = self.v_dim = head_dim
1035
+ self.k_num_heads = self.v_num_heads = config.num_key_value_heads
1036
+ assert self.q_num_heads % self.k_num_heads == 0
1037
+ self.n_group = self.q_num_heads // self.k_num_heads
1038
+
1039
+ self.q_proj_dim = self.q_num_heads * self.qk_dim
1040
+ self.k_proj_dim = self.k_num_heads * self.qk_dim
1041
+ self.v_proj_dim = self.k_num_heads * self.v_dim
1042
+ self.qkv_proj = nn.Linear(self.hidden_size, self.q_proj_dim + self.k_proj_dim + self.v_proj_dim, bias=False)
1043
+ self.o_proj = nn.Linear(self.q_num_heads * self.v_dim, self.hidden_size, bias=False)
1044
+
1045
+ self.q_weight = torch.nn.Parameter(torch.ones((self.q_num_heads, self.qk_dim)))
1046
+ self.k_weight = torch.nn.Parameter(torch.ones((self.k_num_heads, self.qk_dim)))
1047
+
1048
+ self.rotary_emb = RotaryEmbedding(self.qk_dim, max_position_embeddings=self.config.attention_window_size)
1049
+
1050
+ def forward(
1051
+ self,
1052
+ hidden_states: torch.Tensor,
1053
+ attention_mask: Optional[torch.Tensor] = None,
1054
+ past_states: Optional[PlamoCache] = None,
1055
+ output_attentions: bool = False,
1056
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[PlamoCache]]:
1057
+ bsz, q_len, _ = hidden_states.size()
1058
+
1059
+ qkv = self.qkv_proj(hidden_states)
1060
+ query_states, key_states, value_states = torch.split(
1061
+ qkv, [self.q_proj_dim, self.k_proj_dim, self.v_proj_dim], dim=-1
1062
+ )
1063
+ query_states = query_states.view(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(1, 2)
1064
+ key_states = key_states.view(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(1, 2)
1065
+ value_states = value_states.view(bsz, q_len, self.v_num_heads, self.v_dim).transpose(1, 2)
1066
+
1067
+ attn_dtype = query_states.dtype
1068
+
1069
+ query_states = _rms_norm(query_states, None, 1e-6) * self.q_weight[None, :, None]
1070
+ key_states = _rms_norm(key_states, None, 1e-6) * self.k_weight[None, :, None]
1071
+
1072
+ if past_states is not None:
1073
+ # reuse k, v, self_attention
1074
+ key_states_new = key_states
1075
+ value_states_new = value_states
1076
+ key_states, value_states = past_states.append_kv(key_states, value_states, self.layer_idx) # type: ignore
1077
+ past_states.update_attention(key_states_new, value_states_new, self.layer_idx)
1078
+
1079
+ kv_seq_len = key_states.shape[-2]
1080
+ device = hidden_states.device
1081
+ position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=device)[None]
1082
+ q_position_ids = position_ids[:, -query_states.shape[2] :]
1083
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
1084
+ query_states = _rotary_pos_emb(query_states, cos, sin, q_position_ids)
1085
+ key_states = _rotary_pos_emb(key_states, cos, sin, position_ids)
1086
+ # [bsz, nh, t, hd]
1087
+
1088
+ def _expand_kv(t: torch.Tensor, repeat: int, target: int) -> torch.Tensor:
1089
+ t = torch.repeat_interleave(t, repeat, dim=1)
1090
+ return t[:, :target]
1091
+
1092
+ # expand shared kv
1093
+ assert self.k_num_heads == self.v_num_heads
1094
+ key_states = _expand_kv(key_states, self.n_group, self.q_num_heads)
1095
+ value_states = _expand_kv(value_states, self.n_group, self.q_num_heads)
1096
+
1097
+ full_attn = self.layer_idx in self.config.full_attention_idx
1098
+
1099
+ query_states = query_states.to(attn_dtype)
1100
+ key_states = key_states.to(attn_dtype)
1101
+ value_states = value_states.to(attn_dtype)
1102
+ if attention_mask is not None and attention_mask.dtype != torch.bool:
1103
+ attention_mask = attention_mask.to(attn_dtype)
1104
+ if attention_mask is None:
1105
+ if not full_attn:
1106
+ assert key_states.shape[2] <= self.config.attention_window_size + 1
1107
+ attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=True)
1108
+ else:
1109
+ if attention_mask.dtype == torch.bool:
1110
+ attention_mask = torch.where(attention_mask, torch.tensor(0.0, dtype=torch.float), float("-inf"))
1111
+ if len(attention_mask.shape) == 2:
1112
+ attention_mask = attention_mask[None, None]
1113
+ assert len(attention_mask.shape) == 4
1114
+
1115
+ if not full_attn:
1116
+ m_swa = swa_mask(
1117
+ query_states.shape[2], key_states.shape[2], query_states.device, self.config.attention_window_size
1118
+ )
1119
+ # the `generate` function creates an attention mask that does not account for the sliding window
1120
+ m_swa = m_swa[None, None]
1121
+ attention_mask = attention_mask[:, :, -query_states.shape[2] :, -key_states.shape[2] :]
1122
+ attention_mask = torch.where(m_swa, attention_mask, float("-inf"))
1123
+
1124
+ # like AttentionMaskConverter._unmask_unattended in huggingface/transformers,
1125
+ # we need to attend to all tokens in masked rows for `scaled_dot_product_attention`
1126
+ bool_mask = torch.logical_not(torch.isneginf(attention_mask))
1127
+ valid_tokens = torch.sum(bool_mask, dim=-1).bool() # (..., q_len)
1128
+ attention_mask = torch.where(valid_tokens[..., None], attention_mask, float(0.0))
1129
+ attn_output = F.scaled_dot_product_attention(
1130
+ query_states, key_states, value_states, attn_mask=attention_mask
1131
+ )
1132
+
1133
+ attn_output = attn_output.transpose(1, 2)
1134
+
1135
+ attn_output = attn_output.reshape(bsz, q_len, self.q_num_heads * self.v_dim)
1136
+ attn_output = self.o_proj(attn_output)
1137
+
1138
+ if not output_attentions:
1139
+ attn_weights = None
1140
+
1141
+ return attn_output, attn_weights, past_states
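`_expand_kv` implements the grouped-query sharing: each key/value head is repeated `n_group` times along the head dimension so it lines up with the query heads. A shape-only sketch with made-up sizes (one KV head shared by 16 query heads):

```python
import torch

bsz, kv_heads, q_heads, seq_len, head_dim = 2, 1, 16, 5, 128
k = torch.randn(bsz, kv_heads, seq_len, head_dim)

# Same expansion as _expand_kv(k, repeat=n_group, target=q_num_heads).
expanded = torch.repeat_interleave(k, q_heads // kv_heads, dim=1)[:, :q_heads]
print(expanded.shape)  # torch.Size([2, 16, 5, 128])
```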
1142
+
1143
+
1144
+ class MLP(nn.Module):
1145
+ def __init__(self, config: PlamoConfig) -> None:
1146
+ super().__init__()
1147
+ self.config = config
1148
+ self.hidden_size = config.hidden_size
1149
+ self.intermediate_size = config.intermediate_size
1150
+ self.gate_up_proj = torch.nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
1151
+ self.down_proj = torch.nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
1152
+
1153
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1154
+ h = self.gate_up_proj(x)
1155
+ h = _swiglu(h)
1156
+ return self.down_proj(h) # type: ignore
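`MLP` projects to `2 * intermediate_size` and applies `_swiglu`, which is defined earlier in the file and not shown in this part of the diff. Below is a reference sketch under the assumption that `_swiglu` is the standard SwiGLU (split the doubled projection into a gate half and a value half, gate with SiLU); the split order is an assumption, not taken from the file:

```python
import torch
import torch.nn.functional as F

def swiglu_reference(h: torch.Tensor) -> torch.Tensor:
    # Assumed behaviour of `_swiglu`: halve the doubled projection and gate.
    gate, up = h.chunk(2, dim=-1)
    return F.silu(gate) * up

h = torch.randn(2, 4, 2 * 8)        # (batch, seq, 2 * intermediate_size)
print(swiglu_reference(h).shape)    # torch.Size([2, 4, 8])
```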
1157
+
1158
+
1159
+ class PlamoDecoderLayer(torch.nn.Module):
1160
+ def __init__(self, config: PlamoConfig, is_mamba: bool, layer_idx: int) -> None:
1161
+ super().__init__()
1162
+ self.config = config
1163
+ self.hidden_size = config.hidden_size
1164
+ self.is_mamba = is_mamba
1165
+ self.mixer: torch.nn.Module
1166
+ if is_mamba:
1167
+ self.mixer = Mamba(config, layer_idx)
1168
+ else:
1169
+ self.mixer = Attention(config, layer_idx)
1170
+ self.mlp = MLP(config)
1171
+ """
1172
+ Note: model performance degraded when all offsets were set to 1.
1173
+ """
1174
+ self.pre_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0)
1175
+ self.post_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / 5)
1176
+ self.pre_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0)
1177
+ self.post_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / (5**1.5))
1178
+
1179
+ def forward(
1180
+ self,
1181
+ hidden_states: torch.Tensor,
1182
+ attention_mask: Optional[torch.Tensor] = None,
1183
+ past_state: Optional[PlamoCache] = None,
1184
+ output_attentions: Optional[bool] = False,
1185
+ ) -> Tuple[Any, ...]:
1186
+ # from LlamaDecoder
1187
+ residual = hidden_states
1188
+ hidden_states = self.pre_mixer_norm(hidden_states)
1189
+
1190
+ # Self Attention
1191
+ if self.is_mamba:
1192
+ hidden_states_sa, present_key_value = self.mixer(
1193
+ hidden_states=hidden_states,
1194
+ attention_mask=attention_mask,
1195
+ past_states=past_state,
1196
+ )
1197
+ self_attn_weights = None
1198
+ else:
1199
+ hidden_states_sa, self_attn_weights, present_key_value = self.mixer(
1200
+ hidden_states=hidden_states,
1201
+ attention_mask=attention_mask,
1202
+ past_states=past_state,
1203
+ output_attentions=output_attentions,
1204
+ )
1205
+
1206
+ hidden_states_sa = self.post_mixer_norm(hidden_states_sa)
1207
+ hidden_states = residual + hidden_states_sa
1208
+
1209
+ residual = hidden_states
1210
+ hidden_states = self.pre_mlp_norm(hidden_states)
1211
+
1212
+ # Fully Connected
1213
+ hidden_states_mlp = self.mlp(hidden_states)
1214
+
1215
+ # Residual
1216
+ hidden_states_mlp = self.post_mlp_norm(hidden_states_mlp)
1217
+ hidden_states = residual + hidden_states_mlp
1218
+
1219
+ outputs: Any = (hidden_states,)
1220
+
1221
+ if output_attentions:
1222
+ outputs += (self_attn_weights,)
1223
+
1224
+ return outputs # type: ignore
1225
+
1226
+
1227
+ def is_mamba(config: PlamoConfig, i: int) -> bool:
1228
+ if not config.mamba_enabled:
1229
+ return False
1230
+ assert config.mamba_step > 1
1231
+ assert i < config.num_hidden_layers
1232
+
1233
+ if config.num_hidden_layers <= (config.mamba_step // 2):
1234
+ # use attention in last layer
1235
+ return i != config.num_hidden_layers - 1
1236
+ return (i % config.mamba_step) != (config.mamba_step // 2)
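For intuition, the rule above places one attention mixer every `mamba_step` layers and Mamba mixers everywhere else (with a fallback that keeps attention in the last layer of very shallow models). An illustration with made-up values, 16 layers and `mamba_step = 2`:

```python
mamba_step, num_hidden_layers = 2, 16  # hypothetical values for illustration

pattern = [
    "mamba" if (i % mamba_step) != (mamba_step // 2) else "attention"
    for i in range(num_hidden_layers)
]
print(pattern[:4])  # ['mamba', 'attention', 'mamba', 'attention']
```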
1237
+
1238
+
1239
+ class PlamoDecoder(torch.nn.Module):
1240
+ def __init__(self, config: PlamoConfig) -> None:
1241
+ super().__init__()
1242
+
1243
+ self.layers = torch.nn.ModuleList(
1244
+ [
1245
+ PlamoDecoderLayer(config, is_mamba=is_mamba(config, i), layer_idx=i)
1246
+ for i in range(config.num_hidden_layers)
1247
+ ]
1248
+ )
1249
+ self.gradient_checkpointing = False
1250
+
1251
+ def forward(self, x: DecoderInput) -> DecoderOutput:
1252
+ all_hidden_states: Optional[Tuple[torch.Tensor, ...]] = () if x.output_hidden_states else None
1253
+ all_self_attns: Optional[Tuple[torch.Tensor, ...]] = () if x.output_attentions else None
1254
+ hidden_states = x.hidden_states
1255
+
1256
+ for decoder_layer in self.layers:
1257
+ if x.output_hidden_states:
1258
+ assert all_hidden_states is not None
1259
+ all_hidden_states += (hidden_states,)
1260
+
1261
+ if self.training and x.gradient_checkpointing:
1262
+ layer_outputs = self._gradient_checkpointing_func(
1263
+ decoder_layer.__call__,
1264
+ hidden_states,
1265
+ x.attention_mask,
1266
+ x.past_states,
1267
+ x.output_attentions,
1268
+ )
1269
+ else:
1270
+ layer_outputs = decoder_layer(
1271
+ hidden_states,
1272
+ attention_mask=x.attention_mask,
1273
+ past_state=x.past_states,
1274
+ output_attentions=x.output_attentions,
1275
+ )
1276
+
1277
+ hidden_states = layer_outputs[0]
1278
+
1279
+ if x.output_attentions:
1280
+ assert layer_outputs[1] is not None
1281
+ assert all_self_attns is not None
1282
+ all_self_attns += (layer_outputs[1],)
1283
+ return DecoderOutput(hidden_states, all_hidden_states, all_self_attns)
1284
+
1285
+
1286
+ class PlamoPreTrainedModel(PreTrainedModel): # type: ignore
1287
+ config_class = PlamoConfig
1288
+ _no_split_modules: List[str]
1289
+ base_model_prefix = "model"
1290
+ supports_gradient_checkpointing = True
1291
+ _no_split_modules = ["PlamoDecoderLayer"]
1292
+ _skip_keys_device_placement = "past_key_values"
1293
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
1294
+
1295
+ def _init_weights(self, module: torch.nn.Module) -> None:
1296
+ std = 0.02
1297
+ if isinstance(module, nn.Linear):
1298
+ module.weight.data.normal_(mean=0.0, std=std)
1299
+ if module.bias is not None:
1300
+ module.bias.data.zero_()
1301
+ elif isinstance(module, nn.Embedding):
1302
+ module.weight.data.normal_(mean=0.0, std=std)
1303
+ if module.padding_idx is not None:
1304
+ module.weight.data[module.padding_idx].zero_()
1305
+
1306
+
1307
+ class PlamoModel(PlamoPreTrainedModel):
1308
+ def __init__(self, config: PlamoConfig):
1309
+ super().__init__(config)
1310
+ assert config.eval_attention_n_bit is None
1311
+ assert config.eval_mlp_n_bit is None
1312
+
1313
+ self.padding_idx = config.pad_token_id
1314
+ self.vocab_size = config.vocab_size
1315
+
1316
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1317
+ if config.image_feature_size is not None:
1318
+ if config.image_proj_type == "mlp":
1319
+ self.image_proj = MLPImageProjector(config) # type: ignore
1320
+ elif config.image_proj_type == "linear":
1321
+ self.image_proj = nn.Linear(config.image_feature_size, config.hidden_size, bias=False) # type: ignore
1322
+ else:
1323
+ raise ValueError(f"Unknown image_proj_type: {config.image_proj_type}")
1324
+ self.layers = PlamoDecoder(config) # type: ignore
1325
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1326
+
1327
+ self.gradient_checkpointing = False
1328
+ # Initialize weights and apply final processing
1329
+ self.post_init()
1330
+
1331
+ def get_input_embeddings(self) -> torch.nn.Embedding:
1332
+ return self.embed_tokens
1333
+
1334
+ def set_input_embeddings(self, value: torch.nn.Embedding) -> None:
1335
+ self.embed_tokens = value
1336
+
1337
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
1338
+ def _prepare_decoder_attention_mask(
1339
+ self,
1340
+ attention_mask: torch.Tensor,
1341
+ input_shape: Tuple[int, int],
1342
+ inputs_embeds: Optional[torch.Tensor],
1343
+ past_key_values_length: int,
1344
+ ) -> Optional[torch.Tensor]:
1345
+ # create causal mask
1346
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1347
+ combined_attention_mask: Optional[torch.Tensor] = None
1348
+ if input_shape[-1] > 1:
1349
+ assert inputs_embeds is not None
1350
+ combined_attention_mask = _make_causal_mask(
1351
+ input_shape,
1352
+ inputs_embeds.dtype,
1353
+ device=inputs_embeds.device,
1354
+ past_key_values_length=past_key_values_length,
1355
+ )
1356
+ input_shape = (input_shape[0], combined_attention_mask.shape[2])
1357
+
1358
+ if attention_mask is not None:
1359
+ if attention_mask.dim() == 4:
1360
+ # Custom 4D attention mask
1361
+ expanded_attn_mask = attention_mask
1362
+ else:
1363
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1364
+ assert inputs_embeds is not None
1365
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
1366
+ inputs_embeds.device
1367
+ )
1368
+ combined_attention_mask = (
1369
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
1370
+ )
1371
+
1372
+ return combined_attention_mask
1373
+
1374
+ def forward(
1375
+ self,
1376
+ input_ids: Optional[torch.LongTensor] = None,
1377
+ attention_mask: Optional[torch.Tensor] = None,
1378
+ position_ids: Optional[torch.Tensor] = None,
1379
+ past_key_values: Optional[PlamoCache] = None,
1380
+ inputs_embeds: Optional[torch.Tensor] = None,
1381
+ image_features: Optional[torch.Tensor] = None,
1382
+ use_cache: Optional[bool] = None,
1383
+ output_attentions: Optional[bool] = None,
1384
+ output_hidden_states: Optional[bool] = None,
1385
+ return_dict: Optional[bool] = None,
1386
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1387
+ assert input_ids is not None
1388
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1389
+ output_hidden_states = (
1390
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1391
+ )
1392
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1393
+
1394
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1395
+
1396
+ # retrieve input_ids and inputs_embeds
1397
+ if input_ids is not None and inputs_embeds is not None:
1398
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1399
+ elif input_ids is not None:
1400
+ batch_size, seq_length = input_ids.shape
1401
+ else:
1402
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1403
+
1404
+ seq_length_with_past = seq_length
1405
+ past_key_values_length = 0
1406
+
1407
+ if past_key_values is not None:
1408
+ past_key_values_length = past_key_values.get_seq_length()
1409
+ seq_length_with_past = seq_length_with_past + past_key_values_length
1410
+
1411
+ if inputs_embeds is None:
1412
+ inputs_embeds = self.embed_tokens(input_ids)
1413
+
1414
+ if image_features is not None:
1415
+ assert self.config.image_token_id is not None
1416
+ image_embeds = self.image_proj(image_features)
1417
+ assert image_embeds.shape == inputs_embeds.shape, (image_embeds.shape, inputs_embeds.shape)
1418
+ mask = input_ids == self.config.image_token_id
1419
+ inputs_embeds[mask] = image_embeds[mask]
1420
+
1421
+ # embed positions
1422
+ require_attn_mask = False
1423
+ if not self.training or past_key_values is not None:
1424
+ require_attn_mask = True
1425
+ if seq_length_with_past >= self.config.attention_window_size:
1426
+ require_attn_mask = True
1427
+ if require_attn_mask and attention_mask is None:
1428
+ attention_mask = torch.ones(
1429
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
1430
+ )
1431
+ if attention_mask is not None:
1432
+ attention_mask = self._prepare_decoder_attention_mask(
1433
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
1434
+ )
1435
+
1436
+ hidden_states = inputs_embeds
1437
+
1438
+ if self.gradient_checkpointing and self.training:
1439
+ if use_cache:
1440
+ use_cache = False
1441
+
1442
+ if use_cache and past_key_values is None:
1443
+ past_key_values = PlamoCache(self.config)
1444
+
1445
+ # decoder layers
1446
+ out = self.layers(
1447
+ DecoderInput(
1448
+ hidden_states,
1449
+ attention_mask,
1450
+ past_key_values,
1451
+ output_hidden_states,
1452
+ output_attentions,
1453
+ self.gradient_checkpointing,
1454
+ )
1455
+ )
1456
+ assert isinstance(out, DecoderOutput)
1457
+ hidden_states = out.hidden_states
1458
+ all_hidden_states = out.all_hidden_states
1459
+ all_self_attns = out.all_self_attns
1460
+
1461
+ hidden_states = self.norm(hidden_states)
1462
+
1463
+ # add hidden states from the last decoder layer
1464
+ if output_hidden_states:
1465
+ assert all_hidden_states is not None
1466
+ all_hidden_states += (hidden_states,)
1467
+
1468
+ if not return_dict:
1469
+ return tuple(
1470
+ v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None
1471
+ )
1472
+ return BaseModelOutputWithPast(
1473
+ last_hidden_state=hidden_states,
1474
+ past_key_values=past_key_values,
1475
+ hidden_states=all_hidden_states,
1476
+ attentions=all_self_attns,
1477
+ )
1478
+
1479
+
1480
+ class PlamoForCausalLM(PlamoPreTrainedModel):
1481
+ _tied_weights_keys = ["lm_head.weight"]
1482
+
1483
+ # Without this, the model cannot be loaded into a meta device.
1484
+ # Relevant code:
1485
+ # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/modeling_utils.py#L4376-L4381
1486
+ # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/modeling_utils.py#L356
1487
+ # https://github.com/pytorch/pytorch/blob/v2.4.1/torch/nn/modules/module.py#L2068
1488
+ _supports_param_buffer_assignment = False
1489
+
1490
+ def __init__(self, config: PlamoConfig) -> None:
1491
+ super().__init__(config)
1492
+ self.model = PlamoModel(config)
1493
+
1494
+ self.vocab_size = config.vocab_size
1495
+ vocab_size = ((self.vocab_size + 15) // 16) * 16
1496
+ self.lm_head: torch.nn.Module = nn.Linear(config.hidden_size, vocab_size, bias=False)
1497
+
1498
+ # Initialize weights and apply final processing
1499
+ self.post_init()
1500
+
1501
+ def get_input_embeddings(self) -> torch.nn.Embedding:
1502
+ return self.model.embed_tokens
1503
+
1504
+ def set_input_embeddings(self, value: torch.nn.Embedding) -> None:
1505
+ self.model.embed_tokens = value
1506
+
1507
+ def get_output_embeddings(self) -> torch.nn.Module:
1508
+ return self.lm_head
1509
+
1510
+ def set_output_embeddings(self, new_embeddings: torch.nn.Module) -> None:
1511
+ self.lm_head = new_embeddings
1512
+
1513
+ def set_decoder(self, decoder: PlamoModel) -> None:
1514
+ self.model = decoder
1515
+
1516
+ def get_decoder(self) -> PlamoModel:
1517
+ return self.model
1518
+
1519
+ def forward( # type: ignore
1520
+ self,
1521
+ input_ids: Optional[torch.LongTensor] = None,
1522
+ attention_mask: Optional[torch.Tensor] = None,
1523
+ position_ids: Optional[torch.Tensor] = None,
1524
+ past_key_values: Optional[PlamoCache] = None,
1525
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1526
+ image_features: Optional[torch.Tensor] = None,
1527
+ labels: Optional[torch.LongTensor] = None,
1528
+ use_cache: Optional[bool] = None,
1529
+ output_attentions: Optional[bool] = None,
1530
+ output_hidden_states: Optional[bool] = None,
1531
+ return_dict: Optional[bool] = None,
1532
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1533
+ r"""
1534
+ Args:
1535
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1536
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1537
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1538
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1539
+
1540
+ Returns:
1541
+
1542
+ Example:
1543
+
1544
+ ```python
1545
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
1546
+
1547
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1548
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1549
+
1550
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1551
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1552
+
1553
+ >>> # Generate
1554
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1555
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1556
+ "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
1557
+ ```"""
1558
+ assert input_ids is not None
1559
+
1560
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1561
+ output_hidden_states = (
1562
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1563
+ )
1564
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1565
+
1566
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1567
+ outputs = self.model(
1568
+ input_ids=input_ids,
1569
+ attention_mask=attention_mask,
1570
+ position_ids=position_ids,
1571
+ past_key_values=past_key_values,
1572
+ inputs_embeds=inputs_embeds,
1573
+ image_features=image_features,
1574
+ use_cache=use_cache,
1575
+ output_attentions=output_attentions,
1576
+ output_hidden_states=output_hidden_states,
1577
+ return_dict=return_dict,
1578
+ )
1579
+
1580
+ hidden_states = outputs[0]
1581
+ logits = self.lm_head(hidden_states)
1582
+ logits = logits[..., : self.vocab_size]
1583
+
1584
+ loss = None
1585
+ if labels is not None:
1586
+ # Shift so that tokens < n predict n
1587
+ shift_logits = logits[..., :-1, :].contiguous()
1588
+ shift_labels = labels[..., 1:].contiguous()
1589
+ # Flatten the tokens
1590
+ loss_fct = nn.CrossEntropyLoss()
1591
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1592
+ shift_labels = shift_labels.view(-1)
1593
+ # Enable model parallelism
1594
+ shift_labels = shift_labels.to(shift_logits.device)
1595
+ loss = loss_fct(shift_logits, shift_labels)
1596
+
1597
+ if not return_dict:
1598
+ output = (logits,) + outputs[1:]
1599
+ return (loss,) + output if loss is not None else output
1600
+
1601
+ return CausalLMOutputWithPast(
1602
+ loss=loss,
1603
+ logits=logits,
1604
+ past_key_values=outputs.past_key_values,
1605
+ hidden_states=outputs.hidden_states,
1606
+ attentions=outputs.attentions,
1607
+ )
1608
+
1609
+ def prepare_inputs_for_generation(
1610
+ self,
1611
+ input_ids: torch.Tensor,
1612
+ past_key_values: Optional[PlamoCache] = None,
1613
+ attention_mask: Optional[torch.Tensor] = None,
1614
+ inputs_embeds: Optional[torch.Tensor] = None,
1615
+ image_features: Optional[torch.Tensor] = None,
1616
+ **kwargs: Any,
1617
+ ) -> Dict[str, Any]:
1618
+ if past_key_values:
1619
+ input_ids = input_ids[:, -1:]
1620
+ if image_features is not None:
1621
+ image_features = image_features[:, -1:, :]
1622
+
1623
+ position_ids = kwargs.get("position_ids", None)
1624
+ if attention_mask is not None and position_ids is None:
1625
+ # create position_ids on the fly for batch generation
1626
+ position_ids = attention_mask.long().cumsum(-1) - 1
1627
+ position_ids.masked_fill_(attention_mask == 0, 1)
1628
+ if past_key_values:
1629
+ position_ids = position_ids[:, -1].unsqueeze(-1)
1630
+
1631
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1632
+ if inputs_embeds is not None and past_key_values is None:
1633
+ model_inputs: Dict[str, Any] = {"inputs_embeds": inputs_embeds}
1634
+ else:
1635
+ model_inputs = {"input_ids": input_ids}
1636
+
1637
+ model_inputs.update(
1638
+ {
1639
+ "position_ids": position_ids,
1640
+ "past_key_values": past_key_values,
1641
+ "use_cache": kwargs.get("use_cache"),
1642
+ "attention_mask": attention_mask,
1643
+ "image_features": image_features,
1644
+ }
1645
+ )
1646
+ return model_inputs
1647
+
1648
+ @staticmethod
1649
+ def _reorder_cache(past_key_values: PlamoCache, beam_idx: torch.Tensor) -> PlamoCache:
1650
+ past_key_values.reorder_cache(beam_idx)
1651
+ return past_key_values
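Note that `lm_head` is padded up to a multiple of 16 output rows in `__init__` and the logits are sliced back to `self.vocab_size` in `forward`. A quick arithmetic illustration with a made-up vocabulary size:

```python
vocab_size = 100003                      # hypothetical, not the model's value
padded = ((vocab_size + 15) // 16) * 16  # same rounding as in __init__
print(padded, padded - vocab_size)       # 100016 13 -> 13 padded rows
```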
1652
+
1653
+
1654
+ class MLPImageProjector(nn.Module):
1655
+ def __init__(self, config: PlamoConfig) -> None:
1656
+ super().__init__()
1657
+ self.config = config
1658
+
1659
+ assert config.image_feature_size is not None # for typing
1660
+
1661
+ # nn.LayerNorm is not supported by PFVM, so use RMSNorm + Bias instead to approximate this.
1662
+ self.norm0 = RMSNorm(config.image_feature_size, eps=config.rms_norm_eps)
1663
+ self.bias0 = Bias(config.image_feature_size)
1664
+
1665
+ # PFVM doesn't support Linear with bias, so add bias manually afterwards.
1666
+ self.linear1 = nn.Linear(config.image_feature_size, config.hidden_size, bias=False)
1667
+ self.bias1 = Bias(config.hidden_size)
1668
+ self.act1 = nn.GELU()
1669
+
1670
+ self.linear2 = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
1671
+ self.bias2 = Bias(config.hidden_size)
1672
+
1673
+ def forward(
1674
+ self,
1675
+ hidden_states: torch.Tensor,
1676
+ ) -> torch.Tensor:
1677
+ hidden_states = self.norm0(hidden_states)
1678
+ hidden_states = self.bias0(hidden_states)
1679
+
1680
+ hidden_states = self.linear1(hidden_states)
1681
+ hidden_states = self.bias1(hidden_states)
1682
+ hidden_states = self.act1(hidden_states)
1683
+
1684
+ hidden_states = self.linear2(hidden_states)
1685
+ hidden_states = self.bias2(hidden_states)
1686
+
1687
+ return hidden_states
1688
+
1689
+
1690
+ class Bias(nn.Module):
1691
+ def __init__(self, num_features: int) -> None:
1692
+ super().__init__()
1693
+ self._bias = nn.Parameter(torch.zeros((num_features,)))
1694
+
1695
+ def forward(
1696
+ self,
1697
+ x: torch.Tensor,
1698
+ ) -> torch.Tensor:
1699
+ return x + self._bias
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<|plamo:bos|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|plamo:eos|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<|plamo:pad|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<|plamo:unk|>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
tokenization_plamo.py ADDED
@@ -0,0 +1,392 @@
1
+ import json
2
+ import math
3
+ import os
4
+ from shutil import copyfile
5
+ from typing import Any, Optional, Tuple
6
+
7
+ import numpy as np
8
+
9
+ # NOTE: numba does not support type hints for njit: https://github.com/python/mypy/issues/16149
10
+ from numba import njit # type: ignore[attr-defined]
11
+ from numba.core import types
12
+ from numba.typed import Dict, List
13
+ from transformers.tokenization_utils import PreTrainedTokenizer
14
+ from transformers.utils import logging
15
+
16
+ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.jsonl"}
17
+ logger = logging.get_logger(__name__)
18
+
19
+ INVALID_SCORE = -20000000
20
+ UNKNOWN_SCORE = -10000000
21
+
22
+ TABLE_PIECE_LENGTH = 0
23
+ TABLE_TOKEN_ID = 1
24
+ TABLE_SCORE = 2
25
+ TABLE_PIECE_ID = 3
26
+
27
+ PATH_TOKEN_LENGTH = 0
28
+ PATH_TOKEN_ID = 1
29
+ PATH_NUM_TOKENS = 2
30
+
31
+
32
+ class AhoCorasick:
33
+ def __init__(self) -> None:
34
+ # List of tokens in the vocabulary.
35
+ self._tokens: list[str]
36
+
37
+ # A mapping from a byte code point to a token ID, used for byte fallback.
38
+ self._bytes: np.ndarray
39
+
40
+ # A mapping from a suffix's piece code to a suffix ID.
41
+ #
42
+ # Typically, the Aho-Corasick algorithm builds a Trie and adds suffix links between nodes
43
+ # of the Trie. In this implementation, a suffix ID corresponds to a node in the trie, and
44
+ # a piece code to an edge (in other words, a pair of a node and the next character).
45
+ #
46
+ # A piece code is a 64-bit integer:
47
+ # - The upper 32 bits store the Unicode code point of the first character.
48
+ # - The lower 32 bits store the suffix ID of the remaining suffix.
49
+ #
50
+ # A suffix ID is an integer indicating the starting position in the _table.
51
+ self._to_suffix_id: Dict[types.int64, types.int32]
52
+
53
+ # Flattened table representing the Trie structure for the Aho-Corasick algorithm.
54
+ # It stores information including scores for each piece (prefix) within each suffix.
55
+ # It is flattened for memory efficiency and performance. Suffixes are stored in
56
+ # lexicographical order of their reversed strings, which improves memory access locality
57
+ # when exploring new characters starting from the string's end. Pieces within a suffix are
58
+ # stored in the decreasing order of their lengths.
59
+ #
60
+ # Each piece (a prefix of the suffix) contains four pieces of information:
61
+ # - TABLE_PIECE_LENGTH: Length of the piece.
62
+ # - TABLE_TOKEN_ID: Token ID (or -1 if the piece is not a valid token).
63
+ # - TABLE_SCORE: Score (or INVALID_SCORE if the piece is not a valid token).
64
+ # - TABLE_PIECE_ID: Piece ID of the suffix.
65
+ #
66
+ # Each suffix also includes a sentinel row with a length of 1, a score of UNKNOWN_SCORE,
67
+ # and a token ID of -1. Sentinel rows are identified by the score being UNKNOWN_SCORE.
68
+ self._table: np.ndarray
69
+
70
+ def build(self, vocab: list[Any]) -> None:
71
+ self._bytes = np.zeros(256, dtype=np.int32)
72
+ self._to_suffix_id = Dict.empty(key_type=types.int64, value_type=types.int32)
73
+
74
+ # Build suffix_to_score and token_to_token_id.
75
+ # The suffix_to_score dictionary maps a suffix to its score. It also includes all suffixes
76
+ # of the token for the Trie structure for the Aho-Corasick algorithm. If a suffix is not a
77
+ # valid token, its score is set to math.nan.
78
+ # The token_to_token_id dictionary maps a token to its token ID.
79
+ suffix_to_score: dict[str, float] = {}
80
+ token_to_token_id: dict[str, int] = {}
81
+ self._tokens = []
82
+ for token_id, row in enumerate(vocab):
83
+ assert isinstance(row[0], str), row
84
+ assert isinstance(row[1], (int, float)), row
85
+
86
+ token = str(row[0])
87
+ self._tokens.append(token)
88
+ token_to_token_id[token] = token_id
89
+
90
+ # Special handling for byte tokens.
91
+ if len(row) > 2 and row[2] == "BYTE":
92
+ assert len(token) == 6 and token.startswith("<0x") and token.endswith(">"), row[0]
93
+ self._bytes[int(row[0][3:5], 16)] = token_id
94
+ continue
95
+
96
+ suffix_to_score[token] = float(row[1])
97
+ # Ensure that all suffixes are included in suffix_to_score.
98
+ for i in range(1, len(token)):
99
+ suffix_to_score[token[i:]] = suffix_to_score.get(token[i:], math.nan)
100
+
101
+ # Ensure all byte tokens are set.
102
+ for i in range(256):
103
+ assert self._bytes[i] != 0, f"Byte token for <0x{i:02X}> is not set."
104
+
105
+ # List suffixes in lexicographical order of their reversed strings.
106
+ suffixes = list(suffix_to_score.keys())
107
+ suffixes.append("")
108
+ suffixes.sort(key=lambda x: x[::-1])
109
+
110
+ # Build suffix_to_id, which is a mapping from a suffix to a suffix ID, and _to_suffix_id,
111
+ # which is a mapping from a piece code to a suffix ID.
112
+ suffix_to_id: dict[str, int] = {}
113
+ num_pieces = 0
114
+ for s in suffixes:
115
+ suffix_to_id[s] = num_pieces
116
+ if s != "":
117
+ self._to_suffix_id[ord(s[0]) << 32 | suffix_to_id[s[1:]]] = np.int32(num_pieces)
118
+ num_pieces += 1 + sum(s[:i] in suffix_to_score for i in range(1, len(s) + 1))
119
+ assert suffix_to_id[""] == 0, suffix_to_id[""]
120
+
121
+ # Build _table, which is a flattened table representing the Trie structure for the Aho-Corasick.
122
+ self._table = np.zeros((num_pieces, 4), dtype=np.int32)
123
+ i = 0
124
+ for suffix in suffixes:
125
+ # Add all prefixes of the suffix to the table.
126
+ for piece_length in range(len(suffix), 0, -1):
127
+ piece = suffix[:piece_length]
128
+ score = suffix_to_score.get(piece, None)
129
+ if score is None:
130
+ continue
131
+ self._table[i, TABLE_PIECE_LENGTH] = piece_length
132
+ self._table[i, TABLE_TOKEN_ID] = token_to_token_id.get(piece, -1)
133
+ self._table[i, TABLE_SCORE] = round(score * 1e4) if math.isfinite(score) else INVALID_SCORE
134
+ self._table[i, TABLE_PIECE_ID] = suffix_to_id[piece]
135
+ i += 1
136
+
137
+ # Add a sentinel row.
138
+ self._table[i, TABLE_PIECE_LENGTH] = 1
139
+ self._table[i, TABLE_TOKEN_ID] = -1
140
+ self._table[i, TABLE_SCORE] = UNKNOWN_SCORE
141
+ i += 1
142
+ assert i == num_pieces, (i, num_pieces)
143
+
144
+ @staticmethod
145
+ @njit
146
+ def _encode(
147
+ to_suffix_id: Dict[types.int64, types.int32],
148
+ table: np.ndarray,
149
+ bytes: np.ndarray,
150
+ data: np.ndarray,
151
+ ) -> np.ndarray:
152
+ # Initialize scores array with a high value and set the score at the end to 0.
153
+ # This array keeps track of the minimum cost (best score) to encode from each position to the end.
154
+ scores = np.full((len(data) + 1,), 2**60, dtype=np.int64)
155
+ scores[-1] = 0
156
+
157
+ # Path array to store the best path information.
158
+ # The path array keeps track of token length, token ID, and number of tokens needed to encode.
159
+ path = np.zeros((len(data) + 1, 3), dtype=np.int32)
160
+
161
+ # Initialize suffix_id to 0, which represents the root of the Trie.
162
+ suffix_id = 0
163
+
164
+ # Process the input data from the end to the beginning.
165
+ for i in range(len(data) - 1, -1, -1):
166
+ c = data[i]
167
+
168
+ # Find the next suffix ID by iterating the suffix IDs of prefixes of the current suffix.
169
+ # NOTE: If no suffix ID is found, suffix_id will be set to 0.
170
+ for p in range(suffix_id, len(table)):
171
+ suffix_id = to_suffix_id.get(c << 32 | table[p, TABLE_PIECE_ID], np.int32(0))
172
+ # If a next suffix ID is found or a sentinel row is reached, break the loop.
173
+ if suffix_id > 0 or table[p, TABLE_SCORE] == UNKNOWN_SCORE:
174
+ break
175
+
176
+ # Update the best path to the current position. If multiple paths have the same score,
177
+ # this chooses the longest prefix as the best path (table is sorted in the decreasing
178
+ # order of piece length).
179
+ for p in range(suffix_id, len(table)):
180
+ score = table[p, TABLE_SCORE]
181
+ if score > INVALID_SCORE:
182
+ piece_length = table[p, TABLE_PIECE_LENGTH]
183
+ s = scores[i + piece_length] - score
184
+ if s < scores[i]:
185
+ scores[i] = s
186
+ path[i, PATH_TOKEN_LENGTH] = piece_length
187
+ path[i, PATH_TOKEN_ID] = table[p, TABLE_TOKEN_ID]
188
+ path[i, PATH_NUM_TOKENS] = path[i + piece_length, PATH_NUM_TOKENS] + 1
189
+ if score == UNKNOWN_SCORE:
190
+ # Add number of bytes to represent `c` in UTF-8 (minus 1; 1 is already
191
+ # added above).
192
+ path[i, PATH_NUM_TOKENS] += (c >= 0x80) + (c >= 0x800) + (c >= 0x10000)
193
+
194
+ # If it reaches a sentinel row, break the loop.
195
+ if score == UNKNOWN_SCORE:
196
+ break
197
+
198
+ # Decode the best path from the beginning to get the token IDs.
199
+ pos = 0
200
+ token_ids = np.zeros(path[0, PATH_NUM_TOKENS], dtype=np.int32)
201
+ token_pos = 0
202
+ while pos < len(data):
203
+ if path[pos, PATH_TOKEN_ID] >= 0:
204
+ token_ids[token_pos] = path[pos, PATH_TOKEN_ID]
205
+ token_pos += 1
206
+ else:
207
+ # Fall back to byte tokens.
208
+ c = data[pos]
209
+ s = 1 + (c >= 0x80) + (c >= 0x800) + (c >= 0x10000)
210
+ # Add byte tokens representing UTF-8 bytes.
211
+ for i in range(s):
212
+ b = c if s == 1 else (0xF00 >> s) & 0xFF if i == 0 else 0x80
213
+ token_ids[token_pos] = bytes[b | ((c >> (s - i - 1) * 6) & 0x3F)]
214
+ token_pos += 1
215
+
216
+ # Ensure that pos increases by at least 1.
217
+ assert path[pos, PATH_TOKEN_LENGTH] > 0, (pos, path[pos])
218
+ pos += path[pos, PATH_TOKEN_LENGTH]
219
+
220
+ return token_ids
221
+
222
+ def encode(self, data: str) -> np.ndarray:
223
+ """Encodes a string into a sequence of token IDs."""
224
+ return np.asarray(
225
+ self._encode(
226
+ self._to_suffix_id,
227
+ self._table,
228
+ self._bytes,
229
+ # Convert a string into a numpy array of Unicode code points.
230
+ # NOTE: This skips UTF-32 BOM.
231
+ np.frombuffer(data.encode("utf-32"), dtype=np.int32)[1:],
232
+ )
233
+ )
234
+
235
+ def encode_as_tokens(self, data: str) -> list[str]:
236
+ """Encodes a string into a sequence of tokens."""
237
+ return [self._tokens[token_id] for token_id in self.encode(data)]
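The `_to_suffix_id` keys are the 64-bit "piece codes" described in the class comment: the upper 32 bits hold the code point of the first character and the lower 32 bits hold the suffix ID of the remainder. A standalone sketch with made-up values:

```python
first_char = "あ"
remaining_suffix_id = 42  # made-up suffix ID for the rest of the string

piece_code = ord(first_char) << 32 | remaining_suffix_id
assert piece_code >> 32 == ord(first_char)
assert piece_code & 0xFFFFFFFF == remaining_suffix_id
```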
238
+
239
+
240
+ class PlamoTokenizer(PreTrainedTokenizer): # type: ignore
241
+ vocab_files_names = VOCAB_FILES_NAMES
242
+ model_input_names = ["input_ids", "attention_mask"]
243
+
244
+ _save_files = [
245
+ "special_tokens_map.json",
246
+ "tokenization_plamo.py",
247
+ "tokenizer.jsonl",
248
+ "tokenizer_config.json",
249
+ ]
250
+
251
+ def __init__(
252
+ self,
253
+ vocab_file: str,
254
+ unk_token: str = "<|plamo:unk|>",
255
+ bos_token: str = "<|plamo:bos|>",
256
+ eos_token: str = "<|plamo:eos|>",
257
+ pad_token: str = "<|plamo:pad|>",
258
+ cls_token: Optional[str] = None,
259
+ sep_token: Optional[str] = None,
260
+ mask_token: Optional[str] = None,
261
+ clean_up_tokenization_spaces: bool = False,
262
+ **kwargs: Any,
263
+ ) -> None:
264
+ """Tokenizer for PLaMo.
265
+
266
+ Args:
267
+ vocab_file (str): Vocabulary file path.
268
+ unk_token (str): Unknown token.
269
+ bos_token (str): Beginning of sentence token.
270
+ eos_token (str): End of sentence token.
271
+ pad_token (str): Padding token.
272
+ cls_token (str):
273
+ Classification token, to extract a summary of an input sequence leveraging self-attention along the
274
+ full depth of the model.
275
+ sep_token (str): Separation token, to separate context and query in an input sequence.
276
+ mask_token (str): Mask token, to use when training a model with masked-language modeling.
277
+ clean_up_tokenization_spaces (bool): Whether or not to clean up the tokenization spaces.
278
+ num_threads (int):
279
+ Number of threads. This value will be ignored if one of `PLAMO_TOKENIZER_NUM_THREADS` or
280
+ `RAYON_NUM_THREADS` is set as an environment variable.
281
+ """
282
+ if "add_bos_token" not in kwargs:
283
+ kwargs["add_bos_token"] = False
284
+ if "add_eos_token" not in kwargs:
285
+ kwargs["add_eos_token"] = False
286
+ self.data: list[Any] = [json.loads(line) for line in open(vocab_file, "r", encoding="utf-8")]
287
+ self.vocab: dict[str, int] = {v[0]: i for i, v in enumerate(self.data)}
288
+ self.aho_corasick = AhoCorasick()
289
+ self.aho_corasick.build(self.data)
290
+ self.vocab_file = vocab_file
291
+ self.add_bos_token = kwargs["add_bos_token"]
292
+ self.add_eos_token = kwargs["add_eos_token"]
293
+
294
+ super().__init__(
295
+ vocab_file=vocab_file,
296
+ unk_token=unk_token,
297
+ bos_token=bos_token,
298
+ eos_token=eos_token,
299
+ pad_token=pad_token,
300
+ cls_token=cls_token,
301
+ sep_token=sep_token,
302
+ mask_token=mask_token,
303
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
304
+ **kwargs,
305
+ )
306
+
307
+ # the functions below are copied from hf transformers LlamaTokenizer's implementation to fix the behaviour of the tokenizer
308
+ # https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/models/llama/tokenization_llama.py
309
+
310
+ def __getstate__(self) -> dict[str, Any]:
311
+ state = self.__dict__.copy()
312
+ state["aho_corasick"] = None
313
+ return state
314
+
315
+ def __setstate__(self, d: dict[str, Any]) -> None:
316
+ self.__dict__ = d
317
+ self.aho_corasick = AhoCorasick()
318
+ self.aho_corasick.build(self.data)
319
+
320
+ @property
321
+ def vocab_size(self) -> Any:
322
+ """Returns vocab size"""
323
+ return len(self.data)
324
+
325
+ def token_to_score(self, token: str) -> Optional[float]:
326
+ """Returns score of the token"""
327
+ token_id = self.vocab.get(token, None)
328
+ return None if token_id is None else self.data[token_id][1]
329
+
330
+ def get_vocab(self) -> dict[str, int]:
331
+ """Returns vocab as a dict"""
332
+ vocab = self.vocab.copy()
333
+ vocab.update(self.added_tokens_encoder)
334
+ return vocab
335
+
336
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
337
+ """Converts a sequence of tokens (string) in a single string."""
338
+ return b"".join(
339
+ [bytes([int(t[3:5], 16)]) if t.startswith("<0x") else t.encode("utf-8") for t in tokens]
340
+ ).decode("utf-8", errors="replace")
341
+
342
+ def _tokenize(self, text: str) -> Any:
343
+ """Returns a tokenized string."""
344
+ return self.aho_corasick.encode_as_tokens(text)
345
+
346
+ def _convert_token_to_id(self, token: str) -> Any:
347
+ """Converts a token (str) in an id using the vocab."""
348
+ return self.vocab.get(token, 0)
349
+
350
+ def _convert_id_to_token(self, index: int) -> Any:
351
+ """Converts an index (integer) in a token (str) using the vocab."""
352
+ return self.data[index][0]
353
+
354
+ def build_inputs_with_special_tokens(
355
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
356
+ ) -> List[int]:
357
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
358
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
359
+
360
+ output = bos_token_id + token_ids_0 + eos_token_id
361
+
362
+ if token_ids_1 is not None:
363
+ output = output + bos_token_id + token_ids_1 + eos_token_id
364
+
365
+ return output
366
+
367
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
368
+ """
369
+ Save the vocabulary and special tokens file to a directory.
370
+
371
+ Args:
372
+ save_directory (`str`):
373
+ The directory in which to save the vocabulary.
374
+
375
+ Returns:
376
+ `Tuple(str)`: Paths to the files saved.
377
+ """
378
+ if not os.path.isdir(save_directory):
379
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
380
+ return ("",)
381
+ out_vocab_file = os.path.join(
382
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
383
+ )
384
+
385
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
386
+ copyfile(self.vocab_file, out_vocab_file)
387
+ elif not os.path.isfile(self.vocab_file):
388
+ with open(out_vocab_file, "w") as f:
389
+ for token in self.data:
390
+ print(json.dumps(token, ensure_ascii=False), file=f)
391
+
392
+ return (out_vocab_file,)
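When a character falls back to byte tokens during encoding, `convert_tokens_to_string` reverses it: `<0xXX>` tokens become raw bytes, ordinary tokens are UTF-8 encoded, and the concatenation is decoded once at the end. A standalone sketch with a made-up token sequence:

```python
tokens = ["Hello", "<0xE3>", "<0x81>", "<0x82>"]  # "あ" as three UTF-8 byte tokens

text = b"".join(
    bytes([int(t[3:5], 16)]) if t.startswith("<0x") else t.encode("utf-8")
    for t in tokens
).decode("utf-8", errors="replace")
print(text)  # Helloあ
```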
tokenizer.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,55 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<|plamo:unk|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<|plamo:bos|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "<|plamo:eos|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "3": {
30
+ "content": "<|plamo:pad|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ }
37
+ },
38
+ "auto_map": {
39
+ "AutoTokenizer": [
40
+ "tokenization_plamo.PlamoTokenizer",
41
+ null
42
+ ]
43
+ },
44
+ "bos_token": "<|plamo:bos|>",
45
+ "clean_up_tokenization_spaces": false,
46
+ "cls_token": null,
47
+ "eos_token": "<|plamo:eos|>",
48
+ "local_file_only": true,
49
+ "mask_token": null,
50
+ "model_max_length": 1000000000000000019884624838656,
51
+ "pad_token": "<|plamo:pad|>",
52
+ "sep_token": null,
53
+ "tokenizer_class": "PlamoTokenizer",
54
+ "unk_token": "<|plamo:unk|>"
55
+ }