jiang719 committed on
Commit 2fed580
1 Parent(s): 8039218

Upload folder using huggingface_hub

__init__.py ADDED
File without changes
config.json ADDED
@@ -0,0 +1,60 @@
+ {
+   "_flash_attn_2_enabled": true,
+   "_name_or_path": "None",
+   "additional_vocab_size": 2,
+   "alpha_initializer": "zeros",
+   "alpha_type": "float",
+   "alphas_initializer_range": 0.0,
+   "architectures": [
+     "WebForVisionText2Text"
+   ],
+   "attention_dropout": 0.0,
+   "auto_map": {
+     "AutoConfig": "configuration_vmistral.VMistralConfig",
+     "AutoModelForCausalLM": "modeling_web.WebForVisionText2Text"
+   },
+   "bos_token_id": 1,
+   "cross_layer_interval": 1,
+   "eos_token_id": 2,
+   "freeze_lm_head": false,
+   "freeze_text_layers": false,
+   "freeze_text_module_exceptions": [],
+   "freeze_vision_layers": false,
+   "freeze_vision_module_exceptions": [],
+   "hidden_act": "silu",
+   "hidden_size": 4096,
+   "image_token_id": 32001,
+   "initializer_range": 0.02,
+   "intermediate_size": 14336,
+   "max_position_embeddings": 32768,
+   "model_type": "vmistral",
+   "num_attention_heads": 32,
+   "num_hidden_layers": 32,
+   "num_key_value_heads": 8,
+   "pad_token_id": 0,
+   "perceiver_config": {
+     "model_type": "vmistral",
+     "qk_layer_norms_perceiver": true,
+     "resampler_depth": 3
+   },
+   "qk_layer_norms": true,
+   "rms_norm_eps": 1e-05,
+   "rope_theta": 10000.0,
+   "sliding_window": 4096,
+   "tie_word_embeddings": false,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.41.1",
+   "use_cache": true,
+   "use_resampler": true,
+   "vision_config": {
+     "hidden_size": 1152,
+     "image_size": 960,
+     "intermediate_size": 4304,
+     "model_type": "vmistral",
+     "num_attention_heads": 16,
+     "num_hidden_layers": 27,
+     "patch_size": 14
+   },
+   "vocab_size": 32000,
+   "web_attention_range": 2
+ }
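The `auto_map` entries above wire the checkpoint to the custom classes uploaded in this folder, so loading goes through `trust_remote_code`. A minimal loading sketch follows; the repository id is taken from `MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP` in `configuration_vmistral.py` below and is an assumption about where this upload is hosted.

```python
# Minimal sketch, assuming the repo id below is where this folder lives.
import torch
from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "lt-asset/Waffle_VLM_WebSight"  # assumed, from MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP

# auto_map routes AutoConfig to configuration_vmistral.VMistralConfig
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
print(config.model_type, config.vision_config.image_size)  # vmistral 960

# auto_map routes AutoModelForCausalLM to modeling_web.WebForVisionText2Text;
# the four safetensors shards below hold roughly 16 GB of bfloat16 weights.
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,  # matches "torch_dtype" in config.json
)
```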
configuration_vmistral.py ADDED
@@ -0,0 +1,310 @@
+ # coding=utf-8
+ # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ VMistral model configuration"""
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+ MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+     "lt-asset/Waffle_VLM_WebSight": "https://huggingface.co/lt-asset/Waffle_VLM_WebSight/blob/main/configuration_vmistral.py",
+ }
+
+
+ class VMistralVisionConfig(PretrainedConfig):
+     r"""
+     """
+     model_type = "vmistral"
+
+     def __init__(
+         self,
+         hidden_size=768,
+         intermediate_size=3072,
+         num_hidden_layers=12,
+         num_attention_heads=12,
+         num_channels=3,
+         image_size=224,
+         patch_size=32,
+         hidden_act="gelu_pytorch_tanh",
+         layer_norm_eps=1e-6,
+         attention_dropout=0.0,
+         initializer_range=0.02,
+         initializer_factor=1.0,
+         web_attention_range=1,
+         _flash_attn_2_enabled=True,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.num_channels = num_channels
+         self.patch_size = patch_size
+         self.image_size = image_size
+         self.initializer_range = initializer_range
+         self.initializer_factor = initializer_factor
+         self.attention_dropout = attention_dropout
+         self.layer_norm_eps = layer_norm_eps
+         self.hidden_act = hidden_act
+         self.web_attention_range = web_attention_range
+         self._flash_attn_2_enabled = _flash_attn_2_enabled
+
+
+ class VMistralPerceiverConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate a
+     Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration
+     with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1.
+
+     [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
+     [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         use_resampler (`bool`, *optional*, defaults to `False`):
+             Whether or not to use the resampler
+         resampler_n_latents (`int`, *optional*, defaults to 64):
+             Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
+         resampler_depth (`int`, *optional*, defaults to 6):
+             Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3).
+         resampler_n_heads (`int`, *optional*, defaults to 16):
+             Number of heads in each Transformer block (for multi-headed self-attention).
+         resampler_head_dim (`int`, *optional*, defaults to 96):
+             Dimensionality of each head projection in the Transformer block.
+         qk_layer_norms_perceiver (`bool`, *optional*, defaults to `False`):
+             Whether or not to use qk layer norms in perceiver
+     """
+     model_type = "vmistral"
+
+     def __init__(
+         self,
+         resampler_n_latents=64,
+         resampler_depth=6,
+         resampler_n_heads=16,
+         resampler_head_dim=96,
+         qk_layer_norms_perceiver=False,
+         **kwargs,
+     ):
+         self.resampler_n_latents = resampler_n_latents
+         self.resampler_depth = resampler_depth
+         self.resampler_n_heads = resampler_n_heads
+         self.resampler_head_dim = resampler_head_dim
+         self.qk_layer_norms_perceiver = qk_layer_norms_perceiver
+
+         super().__init__(**kwargs)
+
+
+ class VMistralConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate a
+     Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration
+     with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1.
+
+     [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
+     [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         additional_vocab_size (`int`, *optional*, defaults to 0):
+             Additional vocabulary size of the model, typically for the special "<img>" token. Additional vocab tokens
+             are always trainable whereas regular vocab tokens can be frozen or not.
+         vocab_size (`int`, *optional*, defaults to 32000):
+             Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the
+             `inputs_ids` passed when calling [`MistralModel`]
+         hidden_size (`int`, *optional*, defaults to 4096):
+             Dimension of the hidden representations.
+         intermediate_size (`int`, *optional*, defaults to 14336):
+             Dimension of the MLP representations.
+         num_hidden_layers (`int`, *optional*, defaults to 32):
+             Number of hidden layers in the Transformer encoder.
+         num_attention_heads (`int`, *optional*, defaults to 32):
+             Number of attention heads for each attention layer in the Transformer encoder.
+         num_key_value_heads (`int`, *optional*, defaults to 8):
+             This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+             `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+             `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+             converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+             by meanpooling all the original heads within that group. For more details check out [this
+             paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
+         hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+             The non-linear activation function (function or string) in the decoder.
+         max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
+             The maximum sequence length that this model might ever be used with. Mistral's sliding window attention
+             allows sequences of up to 4096*32 tokens.
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         alpha_initializer (`str`, *optional*, defaults to `"zeros"`):
+             Initialization type for the alphas.
+         alphas_initializer_range (`float`, *optional*, defaults to 0.0):
+             The standard deviation of the truncated_normal_initializer for initializing the alphas in the Gated Cross
+             Attention.
+         alpha_type (`str`, *optional*, defaults to `"float"`):
+             Whether the gating alphas should be vectors or single floats.
+         rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+             The epsilon used by the rms normalization layers.
+         use_cache (`bool`, *optional*, defaults to `True`):
+             Whether or not the model should return the last key/values attentions (not used by all models). Only
+             relevant if `config.is_decoder=True`.
+         pad_token_id (`int`, *optional*):
+             The id of the padding token.
+         bos_token_id (`int`, *optional*, defaults to 1):
+             The id of the "beginning-of-sequence" token.
+         eos_token_id (`int`, *optional*, defaults to 2):
+             The id of the "end-of-sequence" token.
+         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+             Whether the model's input and output word embeddings should be tied.
+         rope_theta (`float`, *optional*, defaults to 10000.0):
+             The base period of the RoPE embeddings.
+         sliding_window (`int`, *optional*, defaults to 4096):
+             Sliding window attention window size. If not specified, will default to `4096`.
+         cross_layer_interval (`int`, *optional*, defaults to 1):
+             Interval for cross attention (from text to image) layers.
+         qk_layer_norms (`bool`, *optional*, defaults to `False`): Whether to add layer norm after q and k
+         freeze_text_layers (`bool`, *optional*, defaults to `True`): Whether to freeze text layers
+         freeze_text_module_exceptions (`bool`, *optional*, defaults to `[]`):
+             Exceptions to freezing text layers when `freeze_text_layers` is `True`
+         freeze_lm_head (`bool`, *optional*, defaults to `False`): Whether to freeze lm head
+         freeze_vision_layers (`bool`, *optional*, defaults to `True`): Whether to freeze vision layers
+         freeze_vision_module_exceptions (`bool`, *optional*, defaults to `[]`):
+             Exceptions to freezing vision layers when `freeze_vision_layers` is `True`
+         use_resampler (`bool`, *optional*, defaults to `False`): Whether to use the Resampler
+         vision_config (`IdeficsVisionConfig`, *optional*): Custom vision config or dict
+         perceiver_config (`IdeficsPerceiverConfig`, *optional*): Custom perceiver config or dict
+
+     Example:
+     ```python
+     >>> from transformers import MistralModel, MistralConfig
+
+     >>> # Initializing a Mistral 7B style configuration
+     >>> configuration = MistralConfig()
+
+     >>> # Initializing a model from the Mistral 7B style configuration
+     >>> model = MistralModel(configuration)
+
+     >>> # Accessing the model configuration
+     >>> configuration = model.config
+     ```"""
+     model_type = "vmistral"
+     is_composition = False
+
+     def __init__(
+         self,
+         additional_vocab_size=0,
+         vocab_size=32000,
+         hidden_size=4096,
+         intermediate_size=14336,
+         num_hidden_layers=32,
+         num_attention_heads=32,
+         num_key_value_heads=8,
+         hidden_act="silu",
+         max_position_embeddings=4096 * 32,
+         initializer_range=0.02,
+         alpha_initializer="zeros",
+         alphas_initializer_range=0.0,
+         alpha_type="float",
+         rms_norm_eps=1e-6,
+         use_cache=True,
+         pad_token_id=0,  # None in the original configuration_mistral, we set it to the unk_token_id
+         bos_token_id=1,
+         eos_token_id=2,
+         image_token_id=32_001,
+         tie_word_embeddings=False,
+         rope_theta=10000.0,
+         sliding_window=4096,
+         cross_layer_interval=1,
+         qk_layer_norms=False,
+         freeze_text_layers=True,
+         freeze_text_module_exceptions=[],
+         freeze_lm_head=False,
+         freeze_vision_layers=True,
+         freeze_vision_module_exceptions=[],
+         attention_dropout=0.0,
+         _flash_attn_2_enabled=True,
+         use_resampler=False,
+         vision_config=None,
+         perceiver_config=None,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.additional_vocab_size = additional_vocab_size
+         self.image_token_id = image_token_id
+         self.max_position_embeddings = max_position_embeddings
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.sliding_window = sliding_window
+
+         # for backward compatibility
+         if num_key_value_heads is None:
+             num_key_value_heads = num_attention_heads
+
+         self.num_key_value_heads = num_key_value_heads
+         self.hidden_act = hidden_act
+         self.initializer_range = initializer_range
+         self.alpha_initializer = alpha_initializer
+         self.alphas_initializer_range = alphas_initializer_range
+         self.alpha_type = alpha_type
+         self.rms_norm_eps = rms_norm_eps
+         self.use_cache = use_cache
+         self.rope_theta = rope_theta
+
+         self.cross_layer_interval = cross_layer_interval
+         self.qk_layer_norms = qk_layer_norms
+         self.freeze_vision_layers = freeze_vision_layers
+
+         self.freeze_text_layers = freeze_text_layers
+         self.freeze_text_module_exceptions = freeze_text_module_exceptions
+         self.freeze_vision_module_exceptions = freeze_vision_module_exceptions
+         self.freeze_lm_head = freeze_lm_head
+
+         self.use_resampler = use_resampler
+         self._flash_attn_2_enabled = _flash_attn_2_enabled
+         self.attention_dropout = attention_dropout
+
+         if perceiver_config is None:
+             self.perceiver_config = VMistralPerceiverConfig()
+         elif isinstance(perceiver_config, dict):
+             self.perceiver_config = VMistralPerceiverConfig(**perceiver_config)
+         elif isinstance(perceiver_config, VMistralPerceiverConfig):
+             self.perceiver_config = perceiver_config
+
+         if vision_config is None:
+             self.vision_config = VMistralVisionConfig()
+         elif isinstance(vision_config, dict):
+             self.vision_config = VMistralVisionConfig(**vision_config)
+         elif isinstance(vision_config, VMistralVisionConfig):
+             self.vision_config = vision_config
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
+
+         # IMPORTANT: Do not do any __init__ args-based checks in the constructor, since
+         # PretrainedConfig.from_dict first instantiates the class with the config dict and only then
+         # updates the config object with `kwargs` from from_pretrained, so during the instantiation
+         # of this object many attributes have default values and haven't yet been overridden.
+         # Do any required checks inside `from_pretrained` once the superclass' `from_pretrained` was run.
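The configuration classes above can also be instantiated directly. A minimal sketch follows, assuming this folder is on the Python path so `configuration_vmistral` is importable; the argument values mirror config.json above.

```python
from configuration_vmistral import VMistralConfig, VMistralVisionConfig

config = VMistralConfig(
    additional_vocab_size=2,  # extra trainable rows, e.g. for the special "<img>" token
    qk_layer_norms=True,
    use_resampler=True,
    perceiver_config={"qk_layer_norms_perceiver": True, "resampler_depth": 3},
    vision_config=VMistralVisionConfig(hidden_size=1152, image_size=960, patch_size=14),
)
print(config.perceiver_config.resampler_depth)  # 3
print(config.vision_config.image_size)          # 960
```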
generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "pad_token_id": 0,
+   "transformers_version": "4.41.1"
+ }
generation_utils.py ADDED
@@ -0,0 +1,376 @@
+ from typing import Any, Dict, Optional, List
+ import torch
+ from transformers import GenerationMixin
+ from transformers import AutoTokenizer
+ import re
+ import traceback
+
+
+ class WebGenerationMixin(GenerationMixin):
+     def _update_model_kwargs_for_generation(
+         self,
+         outputs,
+         model_kwargs: Dict[str, Any],
+         is_encoder_decoder: bool = False,
+         standardize_cache_format: bool = False,
+     ) -> Dict[str, Any]:
+         # update past_key_values
+
+         model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+             outputs, standardize_cache_format=standardize_cache_format
+         )
+         if getattr(outputs, "state", None) is not None:
+             model_kwargs["state"] = outputs.state
+
+         # update token_type_ids with last value
+         if "token_type_ids" in model_kwargs:
+             token_type_ids = model_kwargs["token_type_ids"]
+             model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
+
+         if not is_encoder_decoder:
+             # update attention mask
+             if "web_attention_mask" not in model_kwargs:
+                 attention_mask = model_kwargs["attention_mask"]
+                 model_kwargs["web_attention_mask"] = torch.tril(torch.ones((attention_mask.shape[-1], attention_mask.shape[-1]), dtype=attention_mask.dtype)).unsqueeze(0)
+
+             if "attention_mask" in model_kwargs:
+                 attention_mask = model_kwargs["attention_mask"]
+                 model_kwargs["attention_mask"] = torch.cat(
+                     [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+                 )
+
+             model_kwargs["html_tree"] = outputs.html_tree
+
+         else:
+             # update decoder attention mask
+             if "decoder_attention_mask" in model_kwargs:
+                 decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+                 model_kwargs["decoder_attention_mask"] = torch.cat(
+                     [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
+                     dim=-1,
+                 )
+
+         if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
+             model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
+         return model_kwargs
+
+     def _reorder_cache(self, past_key_values, beam_idx):
+         raise NotImplementedError(
+             f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to"
+             f" enable beam search for {self.__class__}"
+         )
+
+
+ class TreeNode:
+     def __init__(self, content: list, idx: int):
+         self.open_tag: List[str] = content
+         self.end_tag: Optional[List[str]] = None
+         self.self_closing_tag: Optional[List[str]] = None
+         self.text = ""
+
+         self.name: Optional[str] = None
+         self.parent: Optional['TreeNode'] = None  # Use 'TreeNode' as a string for forward reference
+
+         self.open_tag_range: Optional[List[int]] = None
+         self.end_tag_range: Optional[List[int]] = None
+         self.text_range = [-1, -1]
+         self.self_closing_tag_range = [-1, -1]
+
+         self.idx: int = idx
+         self.children: List['TreeNode'] = []  # List of TreeNode instances
+
+
+     def partially_open(self):
+         if not self.open_tag: return False
+         if any('<' in s for s in self.open_tag) and not any('>' in s for s in self.open_tag):
+             return True
+         return False
+
+     def add_child(self, child):
+         assert child.parent is None, "Child already has a parent"
+         assert child not in self.children, "Child is already in children list"
+         child.parent = self
+         self.children.append(child)
+
+     def get_range(self):
+         if self.text:
+             return list(range(*self.text_range))
+         elif self.self_closing_tag:
+             return list(range(*self.self_closing_tag_range))
+         else:
+             attn_range = []
+             if self.open_tag_range:
+                 attn_range += list(range(*self.open_tag_range))
+             if self.end_tag_range:
+                 attn_range += list(range(*self.end_tag_range))
+             return attn_range
+
+     def __repr__(self):
+         return f"Node(name='{self.open_tag}', idx = {self.idx})"
+
+     def print_tree(self, level=0, input_ids=None, tokenizer=None):
+         if level == 0:
+             print("--------")
+         indent = " " * level
+         if self.text:
+             print(f"{indent}{tokenizer.convert_tokens_to_string(self.text).strip()}, level = {level} ")
+         elif self.self_closing_tag:
+             print(f"{indent}{tokenizer.convert_tokens_to_string(self.self_closing_tag).strip()}, level = {level} ")
+         elif self.open_tag:
+             print(f"{indent}{tokenizer.convert_tokens_to_string(self.open_tag).strip()}, level = {level} ")
+             for child in self.children:
+                 child.print_tree(level + 1, input_ids, tokenizer)
+             if self.end_tag:
+                 print(f"{indent}{tokenizer.convert_tokens_to_string(self.end_tag).strip()}, level = {level} ")
+         else:
+             for child in self.children:
+                 child.print_tree(level + 1, input_ids, tokenizer)
+         if level == 0:
+             print("--------")
+
+     def get_tree(self, level=0, input_ids=None, tokenizer=None):
+         tree_str = ""
+
+         indent = " " * level
+         if self.text:
+             tree_str += f"{indent}{tokenizer.convert_tokens_to_string(self.text).strip()} \n"
+         elif self.self_closing_tag:
+             tree_str += f"{indent}{tokenizer.convert_tokens_to_string(self.self_closing_tag).strip()} \n"
+         elif self.open_tag:
+             tree_str += f"{indent}{tokenizer.convert_tokens_to_string(self.open_tag).strip()} \n"
+             for child in self.children:
+                 tree_str += child.get_tree(level + 1, input_ids, tokenizer)
+             if self.end_tag:
+                 tree_str += f"{indent}{tokenizer.convert_tokens_to_string(self.end_tag).strip()} \n"
+         else:
+             for child in self.children:
+                 tree_str += child.get_tree(level + 1, input_ids, tokenizer)
+
+         return tree_str
+
+
+ class TreeBuilder:
+     def __init__(self, tokenizer: AutoTokenizer = None, root: TreeNode = None, cur_node: TreeNode = None):
+         self.tokenizer = tokenizer
+         self.root = TreeNode(None, 0)
+         self.cur_node = self.root
+         self.buffer = []
+         self.buffer_start_index = 0
+         self.idx = 0
+         self.full_attention_list = None
+         self.web_attention_mask = None
+         self.input_ids = None
+         self.void_elements = [
+             "area",
+             "base",
+             "br",
+             "col",
+             "embed",
+             "hr",
+             "img",
+             "input",
+             "link",
+             "meta",
+             "param",
+             "source",
+             "track",
+             "wbr"
+         ]
+
+     def is_empty(self):
+         return self.root is None
+
+     def in_buffer(self, text):
+         if len(self.buffer) == 0:
+             return False
+         return any(text in s for s in self.buffer)
+
+     def find_buffer(self, text):
+         # Iterate over the list of strings with their indices
+         for index, s in enumerate(self.buffer):
+             if text in s:
+                 return index
+         return -1
+
+     # Function to extract xxx from <xxx> or <xxx yyy>
+     def extract_open_tag_name(self, buffer):
+         input_string = self.tokenizer.convert_tokens_to_string(buffer)
+         match = re.search(r'<\s*(\w+)(?:\s+[^>]*)?>', input_string)
+         if match:
+             return match.group(1)
+         return None
+
+     def extract_close_tag_name(self, buffer):
+         # if isinstance(input_string, list):
+         #     input_string = "".join(input_string).replace('Ċ', '\n').replace('Ġ', ' ').replace('ĉ', '\t')
+         input_string = self.tokenizer.convert_tokens_to_string(buffer)
+         match = re.search(r'</\s*(\w+)(?:\s+[^>]*)?>', input_string)
+         if match:
+             return match.group(1)
+         return None
+
+     def is_not_empty_buffer(self):
+         return self.tokenizer.convert_tokens_to_string(self.buffer).strip() != ''
+
+     def get_parent_and_siblings_attention_range(self):
+         attn_range = []
+         if self.cur_node.parent:
+             parent = self.cur_node.parent
+             if parent.open_tag_range:
+                 attn_range += list(range(*parent.open_tag_range))
+             for child in parent.children:
+                 if child is not self.cur_node:
+                     if child.open_tag and child.end_tag:
+                         attn_range += list(range(*child.open_tag_range))
+                         attn_range += list(range(*child.end_tag_range))
+                     elif child.text:
+                         attn_range += list(range(*child.text_range))
+                     elif child.self_closing_tag:
+                         attn_range += list(range(*child.self_closing_tag_range))
+                     else:
+                         raise Exception("??? line 151, get p and s attention range")
+
+         return attn_range
+
+     def update_buffer(self, cur_decoded_token):
+         # open tag situations
+         assert isinstance(cur_decoded_token, list), f"{cur_decoded_token}"
+         self.buffer += cur_decoded_token
+         assert isinstance(cur_decoded_token[0], str)
+         # print(self.buffer)
+         try:
+             # dealing with end tag
+             if self.in_buffer('</') and self.in_buffer('>') and self.find_buffer('</') <= self.find_buffer('>'):
+                 close_tag_name = self.extract_close_tag_name(self.buffer)
+
+                 if self.cur_node.open_tag and not self.cur_node.end_tag:
+                     assert close_tag_name == self.extract_open_tag_name(self.cur_node.open_tag), f"close_tag_name is {close_tag_name}, with buffer: {self.buffer}, open is-----{self.cur_node.open_tag}---"
+                 elif self.cur_node.text or self.cur_node.self_closing_tag or self.cur_node.end_tag:
+                     content = None
+                     if self.cur_node.text: content = self.cur_node.text
+                     elif self.cur_node.self_closing_tag: content = self.cur_node.self_closing_tag
+                     elif self.cur_node.end_tag: content = self.cur_node.end_tag
+                     self.root.print_tree(0, None, self.tokenizer)
+                     raise Exception(f"This should never happen\n {content}, buffer is {self.buffer}")
+
+                 # assert close_tag_name == extract_open_tag_name(self.cur_node.open_tag), f"close_tag_name is {close_tag_name}, with buffer: {self.buffer}, open is-----{self.cur_node.open_tag}---"
+                 else:
+                     raise Exception(f"having end tag without having an open tag\n {self.cur_node.text}")
+
+                 self.cur_node.end_tag = self.buffer[:self.find_buffer('>') + 1]
+                 self.cur_node.end_tag_range = [self.buffer_start_index, self.buffer_start_index + self.find_buffer('>') + 1]
+                 self.buffer_start_index += self.find_buffer('>') + 1
+                 self.buffer = self.buffer[self.find_buffer('>') + 1:]
+             # dealing with open tag
+             elif self.in_buffer('</'):
+                 if self.cur_node.open_tag and not self.cur_node.end_tag:
+                     pass
+                 elif self.cur_node.text or self.cur_node.self_closing_tag or (self.cur_node.open_tag and self.cur_node.end_tag):
+                     cur_end_tag_index = self.find_buffer('</')
+                     # import pdb;pdb.set_trace()
+                     if self.cur_node.text:
+                         self.cur_node.text += self.buffer[:cur_end_tag_index]
+                         self.cur_node.text_range[1] += len(self.buffer[:cur_end_tag_index])
+                     elif self.cur_node.self_closing_tag:
+                         self.cur_node.self_closing_tag += self.buffer[:cur_end_tag_index]
+                         self.cur_node.self_closing_tag_range[1] += len(self.buffer[:cur_end_tag_index])
+                     else:
+                         self.cur_node.end_tag += self.buffer[:cur_end_tag_index]
+                         self.cur_node.end_tag_range[1] += len(self.buffer[:cur_end_tag_index])
+                     self.buffer_start_index += len(self.buffer[:cur_end_tag_index])
+                     self.buffer = self.buffer[cur_end_tag_index:]
+                     self.cur_node = self.cur_node.parent
+                 else:
+                     raise Exception(f"having end tag without having an open tag\n {self.cur_node.text} {self.cur_node} {self.cur_node.parent.open_tag}")
+
+             elif self.in_buffer('<') and self.in_buffer('>'):
+                 # in the case of self_closing tag
+                 if self.in_buffer('/>'):
+                     self.cur_node.open_tag = None
+                     self.cur_node.self_closing_tag = self.buffer[:self.find_buffer(">") + 1]
+                     self.cur_node.self_closing_tag_range = [self.buffer_start_index, self.buffer_start_index + self.find_buffer('>') + 1]
+                 else:
+                     open_tag_name = self.extract_open_tag_name(self.buffer)
+                     if open_tag_name in self.void_elements:
+                         self.cur_node.open_tag = None
+                         self.cur_node.self_closing_tag = self.buffer[:self.find_buffer(">") + 1]
+                         self.cur_node.self_closing_tag_range = [self.buffer_start_index, self.buffer_start_index + self.find_buffer('>') + 1]
+                     else:
+                         self.cur_node.open_tag = self.buffer[:self.find_buffer(">") + 1]
+                         self.cur_node.open_tag_range = [self.buffer_start_index, self.buffer_start_index + self.find_buffer('>') + 1]
+
+                 self.buffer_start_index += self.find_buffer('>') + 1
+                 self.buffer = self.buffer[self.find_buffer(">") + 1:]
+             elif self.in_buffer('<'):
+                 if self.full_attention_list is None:
+                     self.full_attention_list = self.buffer[:-1]
+                     self.buffer = self.buffer[-1:]
+                     self.buffer_start_index = len(self.full_attention_list)
+                 else:
+                     cur_open_tag_index = self.find_buffer('<')
+                     # full open tag, indicating a pair of open and close tags, or a single open tag
+                     if not self.cur_node.partially_open() and self.cur_node.open_tag:
+                         if self.cur_node.end_tag:
+                             self.cur_node.end_tag += self.buffer[:cur_open_tag_index]
+                             self.cur_node.end_tag_range[1] += len(self.buffer[:cur_open_tag_index])
+                             self.buffer_start_index += len(self.buffer[:cur_open_tag_index])
+                             self.buffer = self.buffer[cur_open_tag_index:]
+                             child_node = TreeNode(self.buffer, self.idx)
+                             if self.cur_node.parent:
+                                 self.cur_node.parent.add_child(child_node)
+                             else:
+                                 raise Exception(f"This should never happen, a html element with full open tag should have a parent, {self.cur_node.open_tag}")
+                             self.idx += 1
+                             self.cur_node = child_node
+                         else:
+                             child_node = TreeNode(self.buffer, self.idx)
+                             self.cur_node.add_child(child_node)
+                             self.idx += 1
+                             self.cur_node = child_node
+                     elif self.cur_node.text or self.cur_node.self_closing_tag:
+                         if self.cur_node.text:
+                             self.cur_node.text += self.buffer[:cur_open_tag_index]
+                             self.cur_node.text_range[1] += len(self.buffer[:cur_open_tag_index])
+                         elif self.cur_node.self_closing_tag:
+                             self.cur_node.self_closing_tag += self.buffer[:cur_open_tag_index]
+                             self.cur_node.self_closing_tag_range[1] += len(self.buffer[:cur_open_tag_index])
+
+                         self.buffer_start_index += len(self.buffer[:cur_open_tag_index])
+                         self.buffer = self.buffer[cur_open_tag_index:]
+                         child_node = TreeNode(self.buffer, self.idx)
+                         self.cur_node.parent.add_child(child_node)
+                         self.idx += 1
+                         self.cur_node = child_node
+             # if the current node has an open tag, and we are encountering texts, we create a new text node, and move down a level
+             elif (self.cur_node.open_tag or self.cur_node.self_closing_tag) and not self.in_buffer('<') and self.is_not_empty_buffer():
+                 child_node = TreeNode(None, self.idx)
+                 child_node.text = self.buffer
+                 child_node.text_range[0] = self.buffer_start_index
+                 child_node.text_range[1] = self.buffer_start_index + len(self.buffer)
+
+                 if self.cur_node.end_tag or self.cur_node.self_closing_tag:
+                     self.cur_node.parent.add_child(child_node)
+                 else:
+                     self.cur_node.add_child(child_node)
+
+                 self.idx += 1
+                 self.cur_node = child_node
+                 self.buffer_start_index += len(self.buffer)
+                 self.buffer = []
+             # if the current node does not have an open tag, but we are encountering text, we add to the existing text node
+             elif self.cur_node.text and not self.in_buffer('<') and self.is_not_empty_buffer():
+                 self.cur_node.text += self.buffer
+                 assert self.cur_node.text_range[0] != -1 and self.cur_node.text_range[1] != -1, f"self.cur_node.text_range[0] and [1] should not be -1 but: {self.cur_node.text_range[0]}, {self.cur_node.text_range[1]}"
+                 self.cur_node.text_range[1] += len(self.buffer)
+                 self.buffer_start_index += len(self.buffer)
+                 self.buffer = []
+
+         except Exception as e:
+             traceback.print_exc()
+             raise Exception(e)
+
+         if self.full_attention_list is None:
+             attn_range = list(range(len(self.buffer)))
+         else:
+             attn_range = list(range(len(self.full_attention_list))) + self.get_parent_and_siblings_attention_range() + self.cur_node.get_range() + [i + self.buffer_start_index for i in list(range(len(self.buffer)))]
+         return attn_range
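To make the flow above concrete, here is a small, self-contained sketch of how `TreeBuilder.update_buffer` consumes a stream of decoded tokens and returns the indices the next token may attend to. The `JoinTokenizer` stub is hypothetical and exists only because `TreeBuilder` just needs `convert_tokens_to_string`; during real generation the model's own tokenizer is passed and `update_buffer` receives the newly decoded token(s) at each step.

```python
# Stand-in for AutoTokenizer: TreeBuilder only calls convert_tokens_to_string.
class JoinTokenizer:
    def convert_tokens_to_string(self, tokens):
        return "".join(tokens)

builder = TreeBuilder(tokenizer=JoinTokenizer())

# Hand-made token stream for `prompt<div>Hello</div>`, fed one token per step
# the way decoded tokens would arrive during generation.
for token in ["prompt", "<", "div", ">", "Hello", "</", "div", ">"]:
    attn_range = builder.update_buffer([token])

# Indices the next position may attend to; the text inside <div> (index 4) is
# pruned once the element is closed, reflecting the tree-structured attention
# this module builds for the web_attention_mask.
print(attn_range)  # [0, 1, 2, 3, 5, 6, 7]

# Reconstructed HTML structure of what has been generated so far.
print(builder.root.get_tree(tokenizer=builder.tokenizer))
```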
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7f3ab17e6766272fd3e1a53624c9b428796aeea8c4a917401c1b7c9814135922
+ size 4895986336
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7c9c4577adcbbfe172eea0ddd138bf858bafbb7d40805290ff0b1033a56ec994
+ size 4915916144
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8c607a9800a94135d0b2498983d92d63adf64d4e3500310d774bf36a2b230f5a
+ size 4915916176
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:24c9fa2dcc493160239537f0b227c04ceb208fe2b2710ad62d3f234e5228a769
+ size 1688301256
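As a quick sanity check (not part of the upload), the four shard sizes above can be compared against the `total_size` recorded in model.safetensors.index.json below; the on-disk files come out slightly larger than `total_size` because each safetensors shard carries its own JSON header in addition to the tensor bytes.

```python
shard_sizes = [4895986336, 4915916144, 4915916176, 1688301256]  # bytes, from the LFS pointers above
print(sum(shard_sizes))  # 16416119912 ≈ 16.4 GB on disk
print(16416014464)       # "total_size" from the index (tensor bytes only)
```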
model.safetensors.index.json ADDED
@@ -0,0 +1,803 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 16416014464
4
+ },
5
+ "weight_map": {
6
+ "lm_head.additional_fc.weight": "model-00004-of-00004.safetensors",
7
+ "lm_head.weight": "model-00004-of-00004.safetensors",
8
+ "model.embed_tokens.additional_embedding.weight": "model-00001-of-00004.safetensors",
9
+ "model.embed_tokens.weight": "model-00001-of-00004.safetensors",
10
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
11
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
12
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
13
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
14
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
15
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
16
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
17
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
18
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
19
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
20
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
21
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
22
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
23
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
24
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
25
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
26
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
27
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
28
+ "model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
29
+ "model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
30
+ "model.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
31
+ "model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
32
+ "model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
33
+ "model.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
34
+ "model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
35
+ "model.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
36
+ "model.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
37
+ "model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
38
+ "model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
39
+ "model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
40
+ "model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
41
+ "model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
42
+ "model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
43
+ "model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
44
+ "model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
45
+ "model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
46
+ "model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
47
+ "model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
48
+ "model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
49
+ "model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
50
+ "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
51
+ "model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
52
+ "model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
53
+ "model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
54
+ "model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
55
+ "model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
56
+ "model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
57
+ "model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
58
+ "model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
59
+ "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
60
+ "model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
61
+ "model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
62
+ "model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
63
+ "model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
64
+ "model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
65
+ "model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
66
+ "model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
67
+ "model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
68
+ "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
69
+ "model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
70
+ "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
71
+ "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
72
+ "model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
73
+ "model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
74
+ "model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
75
+ "model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
76
+ "model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
77
+ "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
78
+ "model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
79
+ "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
80
+ "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
81
+ "model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
82
+ "model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
83
+ "model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
84
+ "model.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
85
+ "model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
86
+ "model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
87
+ "model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
88
+ "model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
89
+ "model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
90
+ "model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
91
+ "model.layers.17.input_layernorm.weight": "model-00003-of-00004.safetensors",
92
+ "model.layers.17.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
93
+ "model.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
94
+ "model.layers.17.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
95
+ "model.layers.17.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
96
+ "model.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
97
+ "model.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
98
+ "model.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
99
+ "model.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
100
+ "model.layers.18.input_layernorm.weight": "model-00003-of-00004.safetensors",
101
+ "model.layers.18.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
102
+ "model.layers.18.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
103
+ "model.layers.18.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
104
+ "model.layers.18.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
105
+ "model.layers.18.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
106
+ "model.layers.18.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
107
+ "model.layers.18.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
108
+ "model.layers.18.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
109
+ "model.layers.19.input_layernorm.weight": "model-00003-of-00004.safetensors",
110
+ "model.layers.19.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
111
+ "model.layers.19.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
112
+ "model.layers.19.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
113
+ "model.layers.19.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
114
+ "model.layers.19.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
115
+ "model.layers.19.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
116
+ "model.layers.19.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
117
+ "model.layers.19.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
118
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
119
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
120
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
121
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
122
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
123
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
124
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
125
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
126
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
127
+ "model.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
128
+ "model.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
129
+ "model.layers.20.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
130
+ "model.layers.20.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
131
+ "model.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
132
+ "model.layers.20.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
133
+ "model.layers.20.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
134
+ "model.layers.20.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
135
+ "model.layers.20.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
136
+ "model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
137
+ "model.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
138
+ "model.layers.21.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
139
+ "model.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
140
+ "model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
141
+ "model.layers.21.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
142
+ "model.layers.21.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
143
+ "model.layers.21.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
144
+ "model.layers.21.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
145
+ "model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
146
+ "model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
147
+ "model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
148
+ "model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
149
+ "model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
150
+ "model.layers.22.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
151
+ "model.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
152
+ "model.layers.22.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
153
+ "model.layers.22.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
154
+ "model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
155
+ "model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
156
+ "model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
157
+ "model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
158
+ "model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
159
+ "model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
160
+ "model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
161
+ "model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
162
+ "model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
163
+ "model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
164
+ "model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
165
+ "model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
166
+ "model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
167
+ "model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
168
+ "model.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
169
+ "model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
170
+ "model.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
171
+ "model.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
172
+ "model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
173
+ "model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
174
+ "model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
175
+ "model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
176
+ "model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
177
+ "model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
178
+ "model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
179
+ "model.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
180
+ "model.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
181
+ "model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
182
+ "model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
183
+ "model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
184
+ "model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
185
+ "model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
186
+ "model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
187
+ "model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
188
+ "model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
189
+ "model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
190
+ "model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
191
+ "model.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
192
+ "model.layers.27.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
193
+ "model.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
194
+ "model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
195
+ "model.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
196
+ "model.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
197
+ "model.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
198
+ "model.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
199
+ "model.layers.28.input_layernorm.weight": "model-00004-of-00004.safetensors",
200
+ "model.layers.28.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
201
+ "model.layers.28.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
202
+ "model.layers.28.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
203
+ "model.layers.28.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
204
+ "model.layers.28.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
205
+ "model.layers.28.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
206
+ "model.layers.28.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
207
+ "model.layers.28.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
208
+ "model.layers.29.input_layernorm.weight": "model-00004-of-00004.safetensors",
209
+ "model.layers.29.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
210
+ "model.layers.29.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
211
+ "model.layers.29.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
212
+ "model.layers.29.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
213
+ "model.layers.29.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
214
+ "model.layers.29.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
215
+ "model.layers.29.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
216
+ "model.layers.29.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
217
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
218
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
219
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
220
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
221
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
222
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
223
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
224
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
225
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
226
+ "model.layers.30.input_layernorm.weight": "model-00004-of-00004.safetensors",
227
+ "model.layers.30.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
228
+ "model.layers.30.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
229
+ "model.layers.30.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
230
+ "model.layers.30.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
231
+ "model.layers.30.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
232
+ "model.layers.30.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
233
+ "model.layers.30.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
234
+ "model.layers.30.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
235
+ "model.layers.31.input_layernorm.weight": "model-00004-of-00004.safetensors",
236
+ "model.layers.31.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
237
+ "model.layers.31.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
238
+ "model.layers.31.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
239
+ "model.layers.31.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
240
+ "model.layers.31.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
241
+ "model.layers.31.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
242
+ "model.layers.31.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
243
+ "model.layers.31.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
244
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
245
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
246
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
247
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
248
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
249
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
250
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
251
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
252
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
253
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
254
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
255
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
256
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
257
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
258
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
259
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
260
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
261
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
262
+ "model.layers.6.input_layernorm.weight": "model-00002-of-00004.safetensors",
263
+ "model.layers.6.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
264
+ "model.layers.6.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
265
+ "model.layers.6.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
266
+ "model.layers.6.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
267
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
268
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
269
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
270
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
271
+ "model.layers.7.input_layernorm.weight": "model-00002-of-00004.safetensors",
272
+ "model.layers.7.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
273
+ "model.layers.7.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
274
+ "model.layers.7.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
275
+ "model.layers.7.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
276
+ "model.layers.7.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
277
+ "model.layers.7.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
278
+ "model.layers.7.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
279
+ "model.layers.7.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
280
+ "model.layers.8.input_layernorm.weight": "model-00002-of-00004.safetensors",
281
+ "model.layers.8.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
282
+ "model.layers.8.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
283
+ "model.layers.8.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
284
+ "model.layers.8.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
285
+ "model.layers.8.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
286
+ "model.layers.8.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
287
+ "model.layers.8.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
288
+ "model.layers.8.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
289
+ "model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
290
+ "model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
291
+ "model.layers.9.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
292
+ "model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
293
+ "model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
294
+ "model.layers.9.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
295
+ "model.layers.9.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
296
+ "model.layers.9.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
297
+ "model.layers.9.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
298
+ "model.modality_projection.act.fc1.weight": "model-00001-of-00004.safetensors",
299
+ "model.modality_projection.act.fc2.weight": "model-00001-of-00004.safetensors",
300
+ "model.modality_projection.fc1.weight": "model-00001-of-00004.safetensors",
301
+ "model.modality_projection.fc2.weight": "model-00001-of-00004.safetensors",
302
+ "model.norm.weight": "model-00004-of-00004.safetensors",
303
+ "model.perceiver_resampler.blocks.0.0.context_layer_norm.bias": "model-00001-of-00004.safetensors",
304
+ "model.perceiver_resampler.blocks.0.0.context_layer_norm.weight": "model-00001-of-00004.safetensors",
305
+ "model.perceiver_resampler.blocks.0.0.k_layer_norm.bias": "model-00001-of-00004.safetensors",
306
+ "model.perceiver_resampler.blocks.0.0.k_layer_norm.weight": "model-00001-of-00004.safetensors",
307
+ "model.perceiver_resampler.blocks.0.0.k_proj.weight": "model-00001-of-00004.safetensors",
308
+ "model.perceiver_resampler.blocks.0.0.latents_layer_norm.bias": "model-00001-of-00004.safetensors",
309
+ "model.perceiver_resampler.blocks.0.0.latents_layer_norm.weight": "model-00001-of-00004.safetensors",
310
+ "model.perceiver_resampler.blocks.0.0.output_proj.weight": "model-00001-of-00004.safetensors",
311
+ "model.perceiver_resampler.blocks.0.0.q_layer_norm.bias": "model-00001-of-00004.safetensors",
312
+ "model.perceiver_resampler.blocks.0.0.q_layer_norm.weight": "model-00001-of-00004.safetensors",
313
+ "model.perceiver_resampler.blocks.0.0.q_proj.weight": "model-00001-of-00004.safetensors",
314
+ "model.perceiver_resampler.blocks.0.0.v_proj.weight": "model-00001-of-00004.safetensors",
315
+ "model.perceiver_resampler.blocks.0.1.c_proj.weight": "model-00001-of-00004.safetensors",
316
+ "model.perceiver_resampler.blocks.0.1.fc.weight": "model-00001-of-00004.safetensors",
317
+ "model.perceiver_resampler.blocks.0.1.ln.bias": "model-00001-of-00004.safetensors",
318
+ "model.perceiver_resampler.blocks.0.1.ln.weight": "model-00001-of-00004.safetensors",
319
+ "model.perceiver_resampler.blocks.1.0.context_layer_norm.bias": "model-00001-of-00004.safetensors",
320
+ "model.perceiver_resampler.blocks.1.0.context_layer_norm.weight": "model-00001-of-00004.safetensors",
321
+ "model.perceiver_resampler.blocks.1.0.k_layer_norm.bias": "model-00001-of-00004.safetensors",
322
+ "model.perceiver_resampler.blocks.1.0.k_layer_norm.weight": "model-00001-of-00004.safetensors",
323
+ "model.perceiver_resampler.blocks.1.0.k_proj.weight": "model-00001-of-00004.safetensors",
324
+ "model.perceiver_resampler.blocks.1.0.latents_layer_norm.bias": "model-00001-of-00004.safetensors",
325
+ "model.perceiver_resampler.blocks.1.0.latents_layer_norm.weight": "model-00001-of-00004.safetensors",
326
+ "model.perceiver_resampler.blocks.1.0.output_proj.weight": "model-00001-of-00004.safetensors",
327
+ "model.perceiver_resampler.blocks.1.0.q_layer_norm.bias": "model-00001-of-00004.safetensors",
328
+ "model.perceiver_resampler.blocks.1.0.q_layer_norm.weight": "model-00001-of-00004.safetensors",
329
+ "model.perceiver_resampler.blocks.1.0.q_proj.weight": "model-00001-of-00004.safetensors",
330
+ "model.perceiver_resampler.blocks.1.0.v_proj.weight": "model-00001-of-00004.safetensors",
331
+ "model.perceiver_resampler.blocks.1.1.c_proj.weight": "model-00001-of-00004.safetensors",
332
+ "model.perceiver_resampler.blocks.1.1.fc.weight": "model-00001-of-00004.safetensors",
333
+ "model.perceiver_resampler.blocks.1.1.ln.bias": "model-00001-of-00004.safetensors",
334
+ "model.perceiver_resampler.blocks.1.1.ln.weight": "model-00001-of-00004.safetensors",
335
+ "model.perceiver_resampler.blocks.2.0.context_layer_norm.bias": "model-00001-of-00004.safetensors",
336
+ "model.perceiver_resampler.blocks.2.0.context_layer_norm.weight": "model-00001-of-00004.safetensors",
337
+ "model.perceiver_resampler.blocks.2.0.k_layer_norm.bias": "model-00001-of-00004.safetensors",
338
+ "model.perceiver_resampler.blocks.2.0.k_layer_norm.weight": "model-00001-of-00004.safetensors",
339
+ "model.perceiver_resampler.blocks.2.0.k_proj.weight": "model-00001-of-00004.safetensors",
340
+ "model.perceiver_resampler.blocks.2.0.latents_layer_norm.bias": "model-00001-of-00004.safetensors",
341
+ "model.perceiver_resampler.blocks.2.0.latents_layer_norm.weight": "model-00001-of-00004.safetensors",
342
+ "model.perceiver_resampler.blocks.2.0.output_proj.weight": "model-00001-of-00004.safetensors",
343
+ "model.perceiver_resampler.blocks.2.0.q_layer_norm.bias": "model-00001-of-00004.safetensors",
344
+ "model.perceiver_resampler.blocks.2.0.q_layer_norm.weight": "model-00001-of-00004.safetensors",
345
+ "model.perceiver_resampler.blocks.2.0.q_proj.weight": "model-00001-of-00004.safetensors",
346
+ "model.perceiver_resampler.blocks.2.0.v_proj.weight": "model-00001-of-00004.safetensors",
347
+ "model.perceiver_resampler.blocks.2.1.c_proj.weight": "model-00001-of-00004.safetensors",
348
+ "model.perceiver_resampler.blocks.2.1.fc.weight": "model-00001-of-00004.safetensors",
349
+ "model.perceiver_resampler.blocks.2.1.ln.bias": "model-00001-of-00004.safetensors",
350
+ "model.perceiver_resampler.blocks.2.1.ln.weight": "model-00001-of-00004.safetensors",
351
+ "model.perceiver_resampler.latents": "model-00001-of-00004.safetensors",
352
+ "model.perceiver_resampler.layer_norm.bias": "model-00001-of-00004.safetensors",
353
+ "model.perceiver_resampler.layer_norm.weight": "model-00001-of-00004.safetensors",
354
+ "model.vision_model.vision_model.embeddings.patch_embedding.bias": "model-00001-of-00004.safetensors",
355
+ "model.vision_model.vision_model.embeddings.patch_embedding.weight": "model-00001-of-00004.safetensors",
356
+ "model.vision_model.vision_model.embeddings.position_embedding.weight": "model-00001-of-00004.safetensors",
357
+ "model.vision_model.vision_model.encoder.layers.0.layer_norm1.bias": "model-00001-of-00004.safetensors",
358
+ "model.vision_model.vision_model.encoder.layers.0.layer_norm1.weight": "model-00001-of-00004.safetensors",
359
+ "model.vision_model.vision_model.encoder.layers.0.layer_norm2.bias": "model-00001-of-00004.safetensors",
360
+ "model.vision_model.vision_model.encoder.layers.0.layer_norm2.weight": "model-00001-of-00004.safetensors",
361
+ "model.vision_model.vision_model.encoder.layers.0.mlp.fc1.bias": "model-00001-of-00004.safetensors",
362
+ "model.vision_model.vision_model.encoder.layers.0.mlp.fc1.weight": "model-00001-of-00004.safetensors",
363
+ "model.vision_model.vision_model.encoder.layers.0.mlp.fc2.bias": "model-00001-of-00004.safetensors",
364
+ "model.vision_model.vision_model.encoder.layers.0.mlp.fc2.weight": "model-00001-of-00004.safetensors",
365
+ "model.vision_model.vision_model.encoder.layers.0.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
366
+ "model.vision_model.vision_model.encoder.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
367
+ "model.vision_model.vision_model.encoder.layers.0.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
368
+ "model.vision_model.vision_model.encoder.layers.0.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
369
+ "model.vision_model.vision_model.encoder.layers.0.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
370
+ "model.vision_model.vision_model.encoder.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
371
+ "model.vision_model.vision_model.encoder.layers.0.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
372
+ "model.vision_model.vision_model.encoder.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
373
+ "model.vision_model.vision_model.encoder.layers.1.layer_norm1.bias": "model-00001-of-00004.safetensors",
374
+ "model.vision_model.vision_model.encoder.layers.1.layer_norm1.weight": "model-00001-of-00004.safetensors",
375
+ "model.vision_model.vision_model.encoder.layers.1.layer_norm2.bias": "model-00001-of-00004.safetensors",
376
+ "model.vision_model.vision_model.encoder.layers.1.layer_norm2.weight": "model-00001-of-00004.safetensors",
377
+ "model.vision_model.vision_model.encoder.layers.1.mlp.fc1.bias": "model-00001-of-00004.safetensors",
378
+ "model.vision_model.vision_model.encoder.layers.1.mlp.fc1.weight": "model-00001-of-00004.safetensors",
379
+ "model.vision_model.vision_model.encoder.layers.1.mlp.fc2.bias": "model-00001-of-00004.safetensors",
380
+ "model.vision_model.vision_model.encoder.layers.1.mlp.fc2.weight": "model-00001-of-00004.safetensors",
381
+ "model.vision_model.vision_model.encoder.layers.1.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
382
+ "model.vision_model.vision_model.encoder.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
383
+ "model.vision_model.vision_model.encoder.layers.1.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
384
+ "model.vision_model.vision_model.encoder.layers.1.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
385
+ "model.vision_model.vision_model.encoder.layers.1.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
386
+ "model.vision_model.vision_model.encoder.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
387
+ "model.vision_model.vision_model.encoder.layers.1.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
388
+ "model.vision_model.vision_model.encoder.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
389
+ "model.vision_model.vision_model.encoder.layers.10.layer_norm1.bias": "model-00001-of-00004.safetensors",
390
+ "model.vision_model.vision_model.encoder.layers.10.layer_norm1.weight": "model-00001-of-00004.safetensors",
391
+ "model.vision_model.vision_model.encoder.layers.10.layer_norm2.bias": "model-00001-of-00004.safetensors",
392
+ "model.vision_model.vision_model.encoder.layers.10.layer_norm2.weight": "model-00001-of-00004.safetensors",
393
+ "model.vision_model.vision_model.encoder.layers.10.mlp.fc1.bias": "model-00001-of-00004.safetensors",
394
+ "model.vision_model.vision_model.encoder.layers.10.mlp.fc1.weight": "model-00001-of-00004.safetensors",
395
+ "model.vision_model.vision_model.encoder.layers.10.mlp.fc2.bias": "model-00001-of-00004.safetensors",
396
+ "model.vision_model.vision_model.encoder.layers.10.mlp.fc2.weight": "model-00001-of-00004.safetensors",
397
+ "model.vision_model.vision_model.encoder.layers.10.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
398
+ "model.vision_model.vision_model.encoder.layers.10.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
399
+ "model.vision_model.vision_model.encoder.layers.10.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
400
+ "model.vision_model.vision_model.encoder.layers.10.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
401
+ "model.vision_model.vision_model.encoder.layers.10.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
402
+ "model.vision_model.vision_model.encoder.layers.10.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
403
+ "model.vision_model.vision_model.encoder.layers.10.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
404
+ "model.vision_model.vision_model.encoder.layers.10.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
405
+ "model.vision_model.vision_model.encoder.layers.11.layer_norm1.bias": "model-00001-of-00004.safetensors",
406
+ "model.vision_model.vision_model.encoder.layers.11.layer_norm1.weight": "model-00001-of-00004.safetensors",
407
+ "model.vision_model.vision_model.encoder.layers.11.layer_norm2.bias": "model-00001-of-00004.safetensors",
408
+ "model.vision_model.vision_model.encoder.layers.11.layer_norm2.weight": "model-00001-of-00004.safetensors",
409
+ "model.vision_model.vision_model.encoder.layers.11.mlp.fc1.bias": "model-00001-of-00004.safetensors",
410
+ "model.vision_model.vision_model.encoder.layers.11.mlp.fc1.weight": "model-00001-of-00004.safetensors",
411
+ "model.vision_model.vision_model.encoder.layers.11.mlp.fc2.bias": "model-00001-of-00004.safetensors",
412
+ "model.vision_model.vision_model.encoder.layers.11.mlp.fc2.weight": "model-00001-of-00004.safetensors",
413
+ "model.vision_model.vision_model.encoder.layers.11.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
414
+ "model.vision_model.vision_model.encoder.layers.11.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
415
+ "model.vision_model.vision_model.encoder.layers.11.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
416
+ "model.vision_model.vision_model.encoder.layers.11.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
417
+ "model.vision_model.vision_model.encoder.layers.11.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
418
+ "model.vision_model.vision_model.encoder.layers.11.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
419
+ "model.vision_model.vision_model.encoder.layers.11.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
420
+ "model.vision_model.vision_model.encoder.layers.11.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
421
+ "model.vision_model.vision_model.encoder.layers.12.layer_norm1.bias": "model-00001-of-00004.safetensors",
422
+ "model.vision_model.vision_model.encoder.layers.12.layer_norm1.weight": "model-00001-of-00004.safetensors",
423
+ "model.vision_model.vision_model.encoder.layers.12.layer_norm2.bias": "model-00001-of-00004.safetensors",
424
+ "model.vision_model.vision_model.encoder.layers.12.layer_norm2.weight": "model-00001-of-00004.safetensors",
425
+ "model.vision_model.vision_model.encoder.layers.12.mlp.fc1.bias": "model-00001-of-00004.safetensors",
426
+ "model.vision_model.vision_model.encoder.layers.12.mlp.fc1.weight": "model-00001-of-00004.safetensors",
427
+ "model.vision_model.vision_model.encoder.layers.12.mlp.fc2.bias": "model-00001-of-00004.safetensors",
428
+ "model.vision_model.vision_model.encoder.layers.12.mlp.fc2.weight": "model-00001-of-00004.safetensors",
429
+ "model.vision_model.vision_model.encoder.layers.12.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
430
+ "model.vision_model.vision_model.encoder.layers.12.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
431
+ "model.vision_model.vision_model.encoder.layers.12.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
432
+ "model.vision_model.vision_model.encoder.layers.12.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
433
+ "model.vision_model.vision_model.encoder.layers.12.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
434
+ "model.vision_model.vision_model.encoder.layers.12.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
435
+ "model.vision_model.vision_model.encoder.layers.12.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
436
+ "model.vision_model.vision_model.encoder.layers.12.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
437
+ "model.vision_model.vision_model.encoder.layers.13.layer_norm1.bias": "model-00001-of-00004.safetensors",
438
+ "model.vision_model.vision_model.encoder.layers.13.layer_norm1.weight": "model-00001-of-00004.safetensors",
439
+ "model.vision_model.vision_model.encoder.layers.13.layer_norm2.bias": "model-00001-of-00004.safetensors",
440
+ "model.vision_model.vision_model.encoder.layers.13.layer_norm2.weight": "model-00001-of-00004.safetensors",
441
+ "model.vision_model.vision_model.encoder.layers.13.mlp.fc1.bias": "model-00001-of-00004.safetensors",
442
+ "model.vision_model.vision_model.encoder.layers.13.mlp.fc1.weight": "model-00001-of-00004.safetensors",
443
+ "model.vision_model.vision_model.encoder.layers.13.mlp.fc2.bias": "model-00001-of-00004.safetensors",
444
+ "model.vision_model.vision_model.encoder.layers.13.mlp.fc2.weight": "model-00001-of-00004.safetensors",
445
+ "model.vision_model.vision_model.encoder.layers.13.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
446
+ "model.vision_model.vision_model.encoder.layers.13.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
447
+ "model.vision_model.vision_model.encoder.layers.13.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
448
+ "model.vision_model.vision_model.encoder.layers.13.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
449
+ "model.vision_model.vision_model.encoder.layers.13.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
450
+ "model.vision_model.vision_model.encoder.layers.13.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
451
+ "model.vision_model.vision_model.encoder.layers.13.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
452
+ "model.vision_model.vision_model.encoder.layers.13.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
453
+ "model.vision_model.vision_model.encoder.layers.14.layer_norm1.bias": "model-00001-of-00004.safetensors",
454
+ "model.vision_model.vision_model.encoder.layers.14.layer_norm1.weight": "model-00001-of-00004.safetensors",
455
+ "model.vision_model.vision_model.encoder.layers.14.layer_norm2.bias": "model-00001-of-00004.safetensors",
456
+ "model.vision_model.vision_model.encoder.layers.14.layer_norm2.weight": "model-00001-of-00004.safetensors",
457
+ "model.vision_model.vision_model.encoder.layers.14.mlp.fc1.bias": "model-00001-of-00004.safetensors",
458
+ "model.vision_model.vision_model.encoder.layers.14.mlp.fc1.weight": "model-00001-of-00004.safetensors",
459
+ "model.vision_model.vision_model.encoder.layers.14.mlp.fc2.bias": "model-00001-of-00004.safetensors",
460
+ "model.vision_model.vision_model.encoder.layers.14.mlp.fc2.weight": "model-00001-of-00004.safetensors",
461
+ "model.vision_model.vision_model.encoder.layers.14.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
462
+ "model.vision_model.vision_model.encoder.layers.14.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
463
+ "model.vision_model.vision_model.encoder.layers.14.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
464
+ "model.vision_model.vision_model.encoder.layers.14.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
465
+ "model.vision_model.vision_model.encoder.layers.14.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
466
+ "model.vision_model.vision_model.encoder.layers.14.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
467
+ "model.vision_model.vision_model.encoder.layers.14.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
468
+ "model.vision_model.vision_model.encoder.layers.14.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
469
+ "model.vision_model.vision_model.encoder.layers.15.layer_norm1.bias": "model-00001-of-00004.safetensors",
470
+ "model.vision_model.vision_model.encoder.layers.15.layer_norm1.weight": "model-00001-of-00004.safetensors",
471
+ "model.vision_model.vision_model.encoder.layers.15.layer_norm2.bias": "model-00001-of-00004.safetensors",
472
+ "model.vision_model.vision_model.encoder.layers.15.layer_norm2.weight": "model-00001-of-00004.safetensors",
473
+ "model.vision_model.vision_model.encoder.layers.15.mlp.fc1.bias": "model-00001-of-00004.safetensors",
474
+ "model.vision_model.vision_model.encoder.layers.15.mlp.fc1.weight": "model-00001-of-00004.safetensors",
475
+ "model.vision_model.vision_model.encoder.layers.15.mlp.fc2.bias": "model-00001-of-00004.safetensors",
476
+ "model.vision_model.vision_model.encoder.layers.15.mlp.fc2.weight": "model-00001-of-00004.safetensors",
477
+ "model.vision_model.vision_model.encoder.layers.15.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
478
+ "model.vision_model.vision_model.encoder.layers.15.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
479
+ "model.vision_model.vision_model.encoder.layers.15.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
480
+ "model.vision_model.vision_model.encoder.layers.15.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
481
+ "model.vision_model.vision_model.encoder.layers.15.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
482
+ "model.vision_model.vision_model.encoder.layers.15.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
483
+ "model.vision_model.vision_model.encoder.layers.15.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
484
+ "model.vision_model.vision_model.encoder.layers.15.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
485
+ "model.vision_model.vision_model.encoder.layers.16.layer_norm1.bias": "model-00001-of-00004.safetensors",
486
+ "model.vision_model.vision_model.encoder.layers.16.layer_norm1.weight": "model-00001-of-00004.safetensors",
487
+ "model.vision_model.vision_model.encoder.layers.16.layer_norm2.bias": "model-00001-of-00004.safetensors",
488
+ "model.vision_model.vision_model.encoder.layers.16.layer_norm2.weight": "model-00001-of-00004.safetensors",
489
+ "model.vision_model.vision_model.encoder.layers.16.mlp.fc1.bias": "model-00001-of-00004.safetensors",
490
+ "model.vision_model.vision_model.encoder.layers.16.mlp.fc1.weight": "model-00001-of-00004.safetensors",
491
+ "model.vision_model.vision_model.encoder.layers.16.mlp.fc2.bias": "model-00001-of-00004.safetensors",
492
+ "model.vision_model.vision_model.encoder.layers.16.mlp.fc2.weight": "model-00001-of-00004.safetensors",
493
+ "model.vision_model.vision_model.encoder.layers.16.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
494
+ "model.vision_model.vision_model.encoder.layers.16.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
495
+ "model.vision_model.vision_model.encoder.layers.16.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
496
+ "model.vision_model.vision_model.encoder.layers.16.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
497
+ "model.vision_model.vision_model.encoder.layers.16.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
498
+ "model.vision_model.vision_model.encoder.layers.16.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
499
+ "model.vision_model.vision_model.encoder.layers.16.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
500
+ "model.vision_model.vision_model.encoder.layers.16.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
501
+ "model.vision_model.vision_model.encoder.layers.17.layer_norm1.bias": "model-00001-of-00004.safetensors",
502
+ "model.vision_model.vision_model.encoder.layers.17.layer_norm1.weight": "model-00001-of-00004.safetensors",
503
+ "model.vision_model.vision_model.encoder.layers.17.layer_norm2.bias": "model-00001-of-00004.safetensors",
504
+ "model.vision_model.vision_model.encoder.layers.17.layer_norm2.weight": "model-00001-of-00004.safetensors",
505
+ "model.vision_model.vision_model.encoder.layers.17.mlp.fc1.bias": "model-00001-of-00004.safetensors",
506
+ "model.vision_model.vision_model.encoder.layers.17.mlp.fc1.weight": "model-00001-of-00004.safetensors",
507
+ "model.vision_model.vision_model.encoder.layers.17.mlp.fc2.bias": "model-00001-of-00004.safetensors",
508
+ "model.vision_model.vision_model.encoder.layers.17.mlp.fc2.weight": "model-00001-of-00004.safetensors",
509
+ "model.vision_model.vision_model.encoder.layers.17.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
510
+ "model.vision_model.vision_model.encoder.layers.17.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
511
+ "model.vision_model.vision_model.encoder.layers.17.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
512
+ "model.vision_model.vision_model.encoder.layers.17.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
513
+ "model.vision_model.vision_model.encoder.layers.17.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
514
+ "model.vision_model.vision_model.encoder.layers.17.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
515
+ "model.vision_model.vision_model.encoder.layers.17.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
516
+ "model.vision_model.vision_model.encoder.layers.17.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
517
+ "model.vision_model.vision_model.encoder.layers.18.layer_norm1.bias": "model-00001-of-00004.safetensors",
518
+ "model.vision_model.vision_model.encoder.layers.18.layer_norm1.weight": "model-00001-of-00004.safetensors",
519
+ "model.vision_model.vision_model.encoder.layers.18.layer_norm2.bias": "model-00001-of-00004.safetensors",
520
+ "model.vision_model.vision_model.encoder.layers.18.layer_norm2.weight": "model-00001-of-00004.safetensors",
521
+ "model.vision_model.vision_model.encoder.layers.18.mlp.fc1.bias": "model-00001-of-00004.safetensors",
522
+ "model.vision_model.vision_model.encoder.layers.18.mlp.fc1.weight": "model-00001-of-00004.safetensors",
523
+ "model.vision_model.vision_model.encoder.layers.18.mlp.fc2.bias": "model-00001-of-00004.safetensors",
524
+ "model.vision_model.vision_model.encoder.layers.18.mlp.fc2.weight": "model-00001-of-00004.safetensors",
525
+ "model.vision_model.vision_model.encoder.layers.18.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
526
+ "model.vision_model.vision_model.encoder.layers.18.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
527
+ "model.vision_model.vision_model.encoder.layers.18.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
528
+ "model.vision_model.vision_model.encoder.layers.18.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
529
+ "model.vision_model.vision_model.encoder.layers.18.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
530
+ "model.vision_model.vision_model.encoder.layers.18.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
531
+ "model.vision_model.vision_model.encoder.layers.18.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
532
+ "model.vision_model.vision_model.encoder.layers.18.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
533
+ "model.vision_model.vision_model.encoder.layers.19.layer_norm1.bias": "model-00001-of-00004.safetensors",
534
+ "model.vision_model.vision_model.encoder.layers.19.layer_norm1.weight": "model-00001-of-00004.safetensors",
535
+ "model.vision_model.vision_model.encoder.layers.19.layer_norm2.bias": "model-00001-of-00004.safetensors",
536
+ "model.vision_model.vision_model.encoder.layers.19.layer_norm2.weight": "model-00001-of-00004.safetensors",
537
+ "model.vision_model.vision_model.encoder.layers.19.mlp.fc1.bias": "model-00001-of-00004.safetensors",
538
+ "model.vision_model.vision_model.encoder.layers.19.mlp.fc1.weight": "model-00001-of-00004.safetensors",
539
+ "model.vision_model.vision_model.encoder.layers.19.mlp.fc2.bias": "model-00001-of-00004.safetensors",
540
+ "model.vision_model.vision_model.encoder.layers.19.mlp.fc2.weight": "model-00001-of-00004.safetensors",
541
+ "model.vision_model.vision_model.encoder.layers.19.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
542
+ "model.vision_model.vision_model.encoder.layers.19.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
543
+ "model.vision_model.vision_model.encoder.layers.19.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
544
+ "model.vision_model.vision_model.encoder.layers.19.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
545
+ "model.vision_model.vision_model.encoder.layers.19.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
546
+ "model.vision_model.vision_model.encoder.layers.19.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
547
+ "model.vision_model.vision_model.encoder.layers.19.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
548
+ "model.vision_model.vision_model.encoder.layers.19.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
549
+ "model.vision_model.vision_model.encoder.layers.2.layer_norm1.bias": "model-00001-of-00004.safetensors",
550
+ "model.vision_model.vision_model.encoder.layers.2.layer_norm1.weight": "model-00001-of-00004.safetensors",
551
+ "model.vision_model.vision_model.encoder.layers.2.layer_norm2.bias": "model-00001-of-00004.safetensors",
552
+ "model.vision_model.vision_model.encoder.layers.2.layer_norm2.weight": "model-00001-of-00004.safetensors",
553
+ "model.vision_model.vision_model.encoder.layers.2.mlp.fc1.bias": "model-00001-of-00004.safetensors",
554
+ "model.vision_model.vision_model.encoder.layers.2.mlp.fc1.weight": "model-00001-of-00004.safetensors",
555
+ "model.vision_model.vision_model.encoder.layers.2.mlp.fc2.bias": "model-00001-of-00004.safetensors",
556
+ "model.vision_model.vision_model.encoder.layers.2.mlp.fc2.weight": "model-00001-of-00004.safetensors",
557
+ "model.vision_model.vision_model.encoder.layers.2.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
558
+ "model.vision_model.vision_model.encoder.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
559
+ "model.vision_model.vision_model.encoder.layers.2.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
560
+ "model.vision_model.vision_model.encoder.layers.2.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
561
+ "model.vision_model.vision_model.encoder.layers.2.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
562
+ "model.vision_model.vision_model.encoder.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
563
+ "model.vision_model.vision_model.encoder.layers.2.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
564
+ "model.vision_model.vision_model.encoder.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
565
+ "model.vision_model.vision_model.encoder.layers.20.layer_norm1.bias": "model-00001-of-00004.safetensors",
566
+ "model.vision_model.vision_model.encoder.layers.20.layer_norm1.weight": "model-00001-of-00004.safetensors",
567
+ "model.vision_model.vision_model.encoder.layers.20.layer_norm2.bias": "model-00001-of-00004.safetensors",
568
+ "model.vision_model.vision_model.encoder.layers.20.layer_norm2.weight": "model-00001-of-00004.safetensors",
569
+ "model.vision_model.vision_model.encoder.layers.20.mlp.fc1.bias": "model-00001-of-00004.safetensors",
570
+ "model.vision_model.vision_model.encoder.layers.20.mlp.fc1.weight": "model-00001-of-00004.safetensors",
571
+ "model.vision_model.vision_model.encoder.layers.20.mlp.fc2.bias": "model-00001-of-00004.safetensors",
572
+ "model.vision_model.vision_model.encoder.layers.20.mlp.fc2.weight": "model-00001-of-00004.safetensors",
573
+ "model.vision_model.vision_model.encoder.layers.20.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
574
+ "model.vision_model.vision_model.encoder.layers.20.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
575
+ "model.vision_model.vision_model.encoder.layers.20.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
576
+ "model.vision_model.vision_model.encoder.layers.20.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
577
+ "model.vision_model.vision_model.encoder.layers.20.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
578
+ "model.vision_model.vision_model.encoder.layers.20.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
579
+ "model.vision_model.vision_model.encoder.layers.20.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
580
+ "model.vision_model.vision_model.encoder.layers.20.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
581
+ "model.vision_model.vision_model.encoder.layers.21.layer_norm1.bias": "model-00001-of-00004.safetensors",
582
+ "model.vision_model.vision_model.encoder.layers.21.layer_norm1.weight": "model-00001-of-00004.safetensors",
583
+ "model.vision_model.vision_model.encoder.layers.21.layer_norm2.bias": "model-00001-of-00004.safetensors",
584
+ "model.vision_model.vision_model.encoder.layers.21.layer_norm2.weight": "model-00001-of-00004.safetensors",
585
+ "model.vision_model.vision_model.encoder.layers.21.mlp.fc1.bias": "model-00001-of-00004.safetensors",
586
+ "model.vision_model.vision_model.encoder.layers.21.mlp.fc1.weight": "model-00001-of-00004.safetensors",
587
+ "model.vision_model.vision_model.encoder.layers.21.mlp.fc2.bias": "model-00001-of-00004.safetensors",
588
+ "model.vision_model.vision_model.encoder.layers.21.mlp.fc2.weight": "model-00001-of-00004.safetensors",
589
+ "model.vision_model.vision_model.encoder.layers.21.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
590
+ "model.vision_model.vision_model.encoder.layers.21.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
591
+ "model.vision_model.vision_model.encoder.layers.21.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
592
+ "model.vision_model.vision_model.encoder.layers.21.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
593
+ "model.vision_model.vision_model.encoder.layers.21.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
594
+ "model.vision_model.vision_model.encoder.layers.21.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
595
+ "model.vision_model.vision_model.encoder.layers.21.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
596
+ "model.vision_model.vision_model.encoder.layers.21.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
597
+ "model.vision_model.vision_model.encoder.layers.22.layer_norm1.bias": "model-00001-of-00004.safetensors",
598
+ "model.vision_model.vision_model.encoder.layers.22.layer_norm1.weight": "model-00001-of-00004.safetensors",
599
+ "model.vision_model.vision_model.encoder.layers.22.layer_norm2.bias": "model-00001-of-00004.safetensors",
600
+ "model.vision_model.vision_model.encoder.layers.22.layer_norm2.weight": "model-00001-of-00004.safetensors",
601
+ "model.vision_model.vision_model.encoder.layers.22.mlp.fc1.bias": "model-00001-of-00004.safetensors",
602
+ "model.vision_model.vision_model.encoder.layers.22.mlp.fc1.weight": "model-00001-of-00004.safetensors",
603
+ "model.vision_model.vision_model.encoder.layers.22.mlp.fc2.bias": "model-00001-of-00004.safetensors",
604
+ "model.vision_model.vision_model.encoder.layers.22.mlp.fc2.weight": "model-00001-of-00004.safetensors",
605
+ "model.vision_model.vision_model.encoder.layers.22.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
606
+ "model.vision_model.vision_model.encoder.layers.22.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
607
+ "model.vision_model.vision_model.encoder.layers.22.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
608
+ "model.vision_model.vision_model.encoder.layers.22.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
609
+ "model.vision_model.vision_model.encoder.layers.22.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
610
+ "model.vision_model.vision_model.encoder.layers.22.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
611
+ "model.vision_model.vision_model.encoder.layers.22.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
612
+ "model.vision_model.vision_model.encoder.layers.22.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
613
+ "model.vision_model.vision_model.encoder.layers.23.layer_norm1.bias": "model-00001-of-00004.safetensors",
614
+ "model.vision_model.vision_model.encoder.layers.23.layer_norm1.weight": "model-00001-of-00004.safetensors",
615
+ "model.vision_model.vision_model.encoder.layers.23.layer_norm2.bias": "model-00001-of-00004.safetensors",
616
+ "model.vision_model.vision_model.encoder.layers.23.layer_norm2.weight": "model-00001-of-00004.safetensors",
617
+ "model.vision_model.vision_model.encoder.layers.23.mlp.fc1.bias": "model-00001-of-00004.safetensors",
618
+ "model.vision_model.vision_model.encoder.layers.23.mlp.fc1.weight": "model-00001-of-00004.safetensors",
619
+ "model.vision_model.vision_model.encoder.layers.23.mlp.fc2.bias": "model-00001-of-00004.safetensors",
620
+ "model.vision_model.vision_model.encoder.layers.23.mlp.fc2.weight": "model-00001-of-00004.safetensors",
621
+ "model.vision_model.vision_model.encoder.layers.23.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
622
+ "model.vision_model.vision_model.encoder.layers.23.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
623
+ "model.vision_model.vision_model.encoder.layers.23.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
624
+ "model.vision_model.vision_model.encoder.layers.23.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
625
+ "model.vision_model.vision_model.encoder.layers.23.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
626
+ "model.vision_model.vision_model.encoder.layers.23.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
627
+ "model.vision_model.vision_model.encoder.layers.23.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
628
+ "model.vision_model.vision_model.encoder.layers.23.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
629
+ "model.vision_model.vision_model.encoder.layers.24.layer_norm1.bias": "model-00001-of-00004.safetensors",
630
+ "model.vision_model.vision_model.encoder.layers.24.layer_norm1.weight": "model-00001-of-00004.safetensors",
631
+ "model.vision_model.vision_model.encoder.layers.24.layer_norm2.bias": "model-00001-of-00004.safetensors",
632
+ "model.vision_model.vision_model.encoder.layers.24.layer_norm2.weight": "model-00001-of-00004.safetensors",
633
+ "model.vision_model.vision_model.encoder.layers.24.mlp.fc1.bias": "model-00001-of-00004.safetensors",
634
+ "model.vision_model.vision_model.encoder.layers.24.mlp.fc1.weight": "model-00001-of-00004.safetensors",
635
+ "model.vision_model.vision_model.encoder.layers.24.mlp.fc2.bias": "model-00001-of-00004.safetensors",
636
+ "model.vision_model.vision_model.encoder.layers.24.mlp.fc2.weight": "model-00001-of-00004.safetensors",
637
+ "model.vision_model.vision_model.encoder.layers.24.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
638
+ "model.vision_model.vision_model.encoder.layers.24.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
639
+ "model.vision_model.vision_model.encoder.layers.24.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
640
+ "model.vision_model.vision_model.encoder.layers.24.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
641
+ "model.vision_model.vision_model.encoder.layers.24.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
642
+ "model.vision_model.vision_model.encoder.layers.24.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
643
+ "model.vision_model.vision_model.encoder.layers.24.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
644
+ "model.vision_model.vision_model.encoder.layers.24.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
645
+ "model.vision_model.vision_model.encoder.layers.25.layer_norm1.bias": "model-00001-of-00004.safetensors",
646
+ "model.vision_model.vision_model.encoder.layers.25.layer_norm1.weight": "model-00001-of-00004.safetensors",
647
+ "model.vision_model.vision_model.encoder.layers.25.layer_norm2.bias": "model-00001-of-00004.safetensors",
648
+ "model.vision_model.vision_model.encoder.layers.25.layer_norm2.weight": "model-00001-of-00004.safetensors",
649
+ "model.vision_model.vision_model.encoder.layers.25.mlp.fc1.bias": "model-00001-of-00004.safetensors",
650
+ "model.vision_model.vision_model.encoder.layers.25.mlp.fc1.weight": "model-00001-of-00004.safetensors",
651
+ "model.vision_model.vision_model.encoder.layers.25.mlp.fc2.bias": "model-00001-of-00004.safetensors",
652
+ "model.vision_model.vision_model.encoder.layers.25.mlp.fc2.weight": "model-00001-of-00004.safetensors",
653
+ "model.vision_model.vision_model.encoder.layers.25.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
654
+ "model.vision_model.vision_model.encoder.layers.25.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
655
+ "model.vision_model.vision_model.encoder.layers.25.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
656
+ "model.vision_model.vision_model.encoder.layers.25.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
657
+ "model.vision_model.vision_model.encoder.layers.25.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
658
+ "model.vision_model.vision_model.encoder.layers.25.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
659
+ "model.vision_model.vision_model.encoder.layers.25.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
660
+ "model.vision_model.vision_model.encoder.layers.25.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
661
+ "model.vision_model.vision_model.encoder.layers.26.layer_norm1.bias": "model-00001-of-00004.safetensors",
662
+ "model.vision_model.vision_model.encoder.layers.26.layer_norm1.weight": "model-00001-of-00004.safetensors",
663
+ "model.vision_model.vision_model.encoder.layers.26.layer_norm2.bias": "model-00001-of-00004.safetensors",
664
+ "model.vision_model.vision_model.encoder.layers.26.layer_norm2.weight": "model-00001-of-00004.safetensors",
665
+ "model.vision_model.vision_model.encoder.layers.26.mlp.fc1.bias": "model-00001-of-00004.safetensors",
666
+ "model.vision_model.vision_model.encoder.layers.26.mlp.fc1.weight": "model-00001-of-00004.safetensors",
667
+ "model.vision_model.vision_model.encoder.layers.26.mlp.fc2.bias": "model-00001-of-00004.safetensors",
668
+ "model.vision_model.vision_model.encoder.layers.26.mlp.fc2.weight": "model-00001-of-00004.safetensors",
669
+ "model.vision_model.vision_model.encoder.layers.26.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
670
+ "model.vision_model.vision_model.encoder.layers.26.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
671
+ "model.vision_model.vision_model.encoder.layers.26.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
672
+ "model.vision_model.vision_model.encoder.layers.26.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
673
+ "model.vision_model.vision_model.encoder.layers.26.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
674
+ "model.vision_model.vision_model.encoder.layers.26.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
675
+ "model.vision_model.vision_model.encoder.layers.26.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
676
+ "model.vision_model.vision_model.encoder.layers.26.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
677
+ "model.vision_model.vision_model.encoder.layers.3.layer_norm1.bias": "model-00001-of-00004.safetensors",
678
+ "model.vision_model.vision_model.encoder.layers.3.layer_norm1.weight": "model-00001-of-00004.safetensors",
679
+ "model.vision_model.vision_model.encoder.layers.3.layer_norm2.bias": "model-00001-of-00004.safetensors",
680
+ "model.vision_model.vision_model.encoder.layers.3.layer_norm2.weight": "model-00001-of-00004.safetensors",
681
+ "model.vision_model.vision_model.encoder.layers.3.mlp.fc1.bias": "model-00001-of-00004.safetensors",
682
+ "model.vision_model.vision_model.encoder.layers.3.mlp.fc1.weight": "model-00001-of-00004.safetensors",
683
+ "model.vision_model.vision_model.encoder.layers.3.mlp.fc2.bias": "model-00001-of-00004.safetensors",
684
+ "model.vision_model.vision_model.encoder.layers.3.mlp.fc2.weight": "model-00001-of-00004.safetensors",
685
+ "model.vision_model.vision_model.encoder.layers.3.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
686
+ "model.vision_model.vision_model.encoder.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
687
+ "model.vision_model.vision_model.encoder.layers.3.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
688
+ "model.vision_model.vision_model.encoder.layers.3.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
689
+ "model.vision_model.vision_model.encoder.layers.3.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
690
+ "model.vision_model.vision_model.encoder.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
691
+ "model.vision_model.vision_model.encoder.layers.3.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
692
+ "model.vision_model.vision_model.encoder.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
693
+ "model.vision_model.vision_model.encoder.layers.4.layer_norm1.bias": "model-00001-of-00004.safetensors",
694
+ "model.vision_model.vision_model.encoder.layers.4.layer_norm1.weight": "model-00001-of-00004.safetensors",
695
+ "model.vision_model.vision_model.encoder.layers.4.layer_norm2.bias": "model-00001-of-00004.safetensors",
696
+ "model.vision_model.vision_model.encoder.layers.4.layer_norm2.weight": "model-00001-of-00004.safetensors",
697
+ "model.vision_model.vision_model.encoder.layers.4.mlp.fc1.bias": "model-00001-of-00004.safetensors",
698
+ "model.vision_model.vision_model.encoder.layers.4.mlp.fc1.weight": "model-00001-of-00004.safetensors",
699
+ "model.vision_model.vision_model.encoder.layers.4.mlp.fc2.bias": "model-00001-of-00004.safetensors",
700
+ "model.vision_model.vision_model.encoder.layers.4.mlp.fc2.weight": "model-00001-of-00004.safetensors",
701
+ "model.vision_model.vision_model.encoder.layers.4.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
702
+ "model.vision_model.vision_model.encoder.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
703
+ "model.vision_model.vision_model.encoder.layers.4.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
704
+ "model.vision_model.vision_model.encoder.layers.4.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
705
+ "model.vision_model.vision_model.encoder.layers.4.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
706
+ "model.vision_model.vision_model.encoder.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
707
+ "model.vision_model.vision_model.encoder.layers.4.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
708
+ "model.vision_model.vision_model.encoder.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
709
+ "model.vision_model.vision_model.encoder.layers.5.layer_norm1.bias": "model-00001-of-00004.safetensors",
710
+ "model.vision_model.vision_model.encoder.layers.5.layer_norm1.weight": "model-00001-of-00004.safetensors",
711
+ "model.vision_model.vision_model.encoder.layers.5.layer_norm2.bias": "model-00001-of-00004.safetensors",
712
+ "model.vision_model.vision_model.encoder.layers.5.layer_norm2.weight": "model-00001-of-00004.safetensors",
713
+ "model.vision_model.vision_model.encoder.layers.5.mlp.fc1.bias": "model-00001-of-00004.safetensors",
714
+ "model.vision_model.vision_model.encoder.layers.5.mlp.fc1.weight": "model-00001-of-00004.safetensors",
715
+ "model.vision_model.vision_model.encoder.layers.5.mlp.fc2.bias": "model-00001-of-00004.safetensors",
716
+ "model.vision_model.vision_model.encoder.layers.5.mlp.fc2.weight": "model-00001-of-00004.safetensors",
717
+ "model.vision_model.vision_model.encoder.layers.5.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
718
+ "model.vision_model.vision_model.encoder.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
719
+ "model.vision_model.vision_model.encoder.layers.5.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
720
+ "model.vision_model.vision_model.encoder.layers.5.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
721
+ "model.vision_model.vision_model.encoder.layers.5.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
722
+ "model.vision_model.vision_model.encoder.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
723
+ "model.vision_model.vision_model.encoder.layers.5.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
724
+ "model.vision_model.vision_model.encoder.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
725
+ "model.vision_model.vision_model.encoder.layers.6.layer_norm1.bias": "model-00001-of-00004.safetensors",
726
+ "model.vision_model.vision_model.encoder.layers.6.layer_norm1.weight": "model-00001-of-00004.safetensors",
727
+ "model.vision_model.vision_model.encoder.layers.6.layer_norm2.bias": "model-00001-of-00004.safetensors",
728
+ "model.vision_model.vision_model.encoder.layers.6.layer_norm2.weight": "model-00001-of-00004.safetensors",
729
+ "model.vision_model.vision_model.encoder.layers.6.mlp.fc1.bias": "model-00001-of-00004.safetensors",
730
+ "model.vision_model.vision_model.encoder.layers.6.mlp.fc1.weight": "model-00001-of-00004.safetensors",
731
+ "model.vision_model.vision_model.encoder.layers.6.mlp.fc2.bias": "model-00001-of-00004.safetensors",
732
+ "model.vision_model.vision_model.encoder.layers.6.mlp.fc2.weight": "model-00001-of-00004.safetensors",
733
+ "model.vision_model.vision_model.encoder.layers.6.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
734
+ "model.vision_model.vision_model.encoder.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
735
+ "model.vision_model.vision_model.encoder.layers.6.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
736
+ "model.vision_model.vision_model.encoder.layers.6.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
737
+ "model.vision_model.vision_model.encoder.layers.6.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
738
+ "model.vision_model.vision_model.encoder.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
739
+ "model.vision_model.vision_model.encoder.layers.6.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
740
+ "model.vision_model.vision_model.encoder.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
741
+ "model.vision_model.vision_model.encoder.layers.7.layer_norm1.bias": "model-00001-of-00004.safetensors",
742
+ "model.vision_model.vision_model.encoder.layers.7.layer_norm1.weight": "model-00001-of-00004.safetensors",
743
+ "model.vision_model.vision_model.encoder.layers.7.layer_norm2.bias": "model-00001-of-00004.safetensors",
744
+ "model.vision_model.vision_model.encoder.layers.7.layer_norm2.weight": "model-00001-of-00004.safetensors",
745
+ "model.vision_model.vision_model.encoder.layers.7.mlp.fc1.bias": "model-00001-of-00004.safetensors",
746
+ "model.vision_model.vision_model.encoder.layers.7.mlp.fc1.weight": "model-00001-of-00004.safetensors",
747
+ "model.vision_model.vision_model.encoder.layers.7.mlp.fc2.bias": "model-00001-of-00004.safetensors",
748
+ "model.vision_model.vision_model.encoder.layers.7.mlp.fc2.weight": "model-00001-of-00004.safetensors",
749
+ "model.vision_model.vision_model.encoder.layers.7.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
750
+ "model.vision_model.vision_model.encoder.layers.7.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
751
+ "model.vision_model.vision_model.encoder.layers.7.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
752
+ "model.vision_model.vision_model.encoder.layers.7.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
753
+ "model.vision_model.vision_model.encoder.layers.7.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
754
+ "model.vision_model.vision_model.encoder.layers.7.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
755
+ "model.vision_model.vision_model.encoder.layers.7.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
756
+ "model.vision_model.vision_model.encoder.layers.7.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
757
+ "model.vision_model.vision_model.encoder.layers.8.layer_norm1.bias": "model-00001-of-00004.safetensors",
758
+ "model.vision_model.vision_model.encoder.layers.8.layer_norm1.weight": "model-00001-of-00004.safetensors",
759
+ "model.vision_model.vision_model.encoder.layers.8.layer_norm2.bias": "model-00001-of-00004.safetensors",
760
+ "model.vision_model.vision_model.encoder.layers.8.layer_norm2.weight": "model-00001-of-00004.safetensors",
761
+ "model.vision_model.vision_model.encoder.layers.8.mlp.fc1.bias": "model-00001-of-00004.safetensors",
762
+ "model.vision_model.vision_model.encoder.layers.8.mlp.fc1.weight": "model-00001-of-00004.safetensors",
763
+ "model.vision_model.vision_model.encoder.layers.8.mlp.fc2.bias": "model-00001-of-00004.safetensors",
764
+ "model.vision_model.vision_model.encoder.layers.8.mlp.fc2.weight": "model-00001-of-00004.safetensors",
765
+ "model.vision_model.vision_model.encoder.layers.8.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
766
+ "model.vision_model.vision_model.encoder.layers.8.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
767
+ "model.vision_model.vision_model.encoder.layers.8.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
768
+ "model.vision_model.vision_model.encoder.layers.8.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
769
+ "model.vision_model.vision_model.encoder.layers.8.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
770
+ "model.vision_model.vision_model.encoder.layers.8.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
771
+ "model.vision_model.vision_model.encoder.layers.8.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
772
+ "model.vision_model.vision_model.encoder.layers.8.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
773
+ "model.vision_model.vision_model.encoder.layers.9.layer_norm1.bias": "model-00001-of-00004.safetensors",
774
+ "model.vision_model.vision_model.encoder.layers.9.layer_norm1.weight": "model-00001-of-00004.safetensors",
775
+ "model.vision_model.vision_model.encoder.layers.9.layer_norm2.bias": "model-00001-of-00004.safetensors",
776
+ "model.vision_model.vision_model.encoder.layers.9.layer_norm2.weight": "model-00001-of-00004.safetensors",
777
+ "model.vision_model.vision_model.encoder.layers.9.mlp.fc1.bias": "model-00001-of-00004.safetensors",
778
+ "model.vision_model.vision_model.encoder.layers.9.mlp.fc1.weight": "model-00001-of-00004.safetensors",
779
+ "model.vision_model.vision_model.encoder.layers.9.mlp.fc2.bias": "model-00001-of-00004.safetensors",
780
+ "model.vision_model.vision_model.encoder.layers.9.mlp.fc2.weight": "model-00001-of-00004.safetensors",
781
+ "model.vision_model.vision_model.encoder.layers.9.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
782
+ "model.vision_model.vision_model.encoder.layers.9.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
783
+ "model.vision_model.vision_model.encoder.layers.9.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
784
+ "model.vision_model.vision_model.encoder.layers.9.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
785
+ "model.vision_model.vision_model.encoder.layers.9.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
786
+ "model.vision_model.vision_model.encoder.layers.9.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
787
+ "model.vision_model.vision_model.encoder.layers.9.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
788
+ "model.vision_model.vision_model.encoder.layers.9.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
789
+ "model.vision_model.vision_model.head.attention.in_proj_bias": "model-00001-of-00004.safetensors",
790
+ "model.vision_model.vision_model.head.attention.in_proj_weight": "model-00001-of-00004.safetensors",
791
+ "model.vision_model.vision_model.head.attention.out_proj.bias": "model-00001-of-00004.safetensors",
792
+ "model.vision_model.vision_model.head.attention.out_proj.weight": "model-00001-of-00004.safetensors",
793
+ "model.vision_model.vision_model.head.layernorm.bias": "model-00001-of-00004.safetensors",
794
+ "model.vision_model.vision_model.head.layernorm.weight": "model-00001-of-00004.safetensors",
795
+ "model.vision_model.vision_model.head.mlp.fc1.bias": "model-00001-of-00004.safetensors",
796
+ "model.vision_model.vision_model.head.mlp.fc1.weight": "model-00001-of-00004.safetensors",
797
+ "model.vision_model.vision_model.head.mlp.fc2.bias": "model-00001-of-00004.safetensors",
798
+ "model.vision_model.vision_model.head.mlp.fc2.weight": "model-00001-of-00004.safetensors",
799
+ "model.vision_model.vision_model.head.probe": "model-00001-of-00004.safetensors",
800
+ "model.vision_model.vision_model.post_layernorm.bias": "model-00001-of-00004.safetensors",
801
+ "model.vision_model.vision_model.post_layernorm.weight": "model-00001-of-00004.safetensors"
802
+ }
803
+ }
modeling_vmistral.py ADDED
@@ -0,0 +1,1766 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch VMistral model."""
21
+ from dataclasses import dataclass
22
+ import inspect
23
+ import math
24
+ import warnings
25
+ from typing import List, Optional, Tuple, Union
26
+
27
+ import torch
28
+ import torch.nn.functional as F
29
+ import torch.utils.checkpoint
30
+ from torch import nn
31
+ from torch.nn import CrossEntropyLoss
32
+ from transformers.activations import ACT2FN
33
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
34
+ from transformers.utils import (
35
+ add_start_docstrings,
36
+ add_start_docstrings_to_model_forward,
37
+ is_flash_attn_2_available,
38
+ replace_return_docstrings,
39
+ )
40
+
41
+ from einops import rearrange, repeat
42
+ from transformers import PreTrainedModel
43
+ from transformers.utils import logging
44
+ from transformers.modeling_outputs import ModelOutput
45
+
46
+ from .configuration_vmistral import VMistralConfig
47
+ from .vision import SiglipVisionModel
48
+
49
+
50
+ if is_flash_attn_2_available():
51
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
52
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
53
+
54
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
55
+
56
+ logger = logging.get_logger(__name__)
57
+
58
+ _CONFIG_FOR_DOC = "VMistralConfig"
59
+
60
+ VMistral_PRETRAINED_MODEL_ARCHIVE_LIST = [
61
+ "HuggingFaceM4/VLM_WebSight_finetuned"
62
+ ]
63
+
64
+ @dataclass
65
+ class VMistralBaseModelOutputWithPast(ModelOutput):
66
+ """
67
+ Base class for VMistral model's outputs that may also contain a past key/values (to speed up sequential decoding).
68
+
69
+ Args:
70
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
71
+ Sequence of hidden-states at the output of the last layer of the model.
72
+
73
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
74
+ hidden_size)` is output.
75
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
76
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
77
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
78
+ `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
79
+ encoder_sequence_length, embed_size_per_head)`.
80
+
81
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
82
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
83
+ input) to speed up sequential decoding.
84
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
85
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
86
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
87
+
88
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
89
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
90
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
91
+ sequence_length)`.
92
+
93
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
94
+ heads.
95
+ image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
96
+ Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
97
+ sequence_length, hidden_size)`.
98
+
99
+ image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
100
+ """
101
+
102
+ last_hidden_state: torch.FloatTensor = None
103
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
104
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
105
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
106
+ image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
107
+
108
+
109
+ @dataclass
110
+ class VMistralCausalLMOutputWithPast(ModelOutput):
111
+ """
112
+ Base class for VMistral causal language model (or autoregressive) outputs.
113
+
114
+ Args:
115
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
116
+ Language modeling loss (for next-token prediction).
117
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
118
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
119
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
120
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
121
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
122
+
123
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
124
+ `past_key_values` input) to speed up sequential decoding.
125
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
126
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
127
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
128
+
129
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
130
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
131
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
132
+ sequence_length)`.
133
+
134
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
135
+ heads.
136
+ image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
137
+ Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
138
+ sequence_length, hidden_size)`.
139
+
140
+ image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
141
+ """
142
+
143
+ loss: Optional[torch.FloatTensor] = None
144
+ logits: torch.FloatTensor = None
145
+ past_key_values: Optional[List[torch.FloatTensor]] = None
146
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
147
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
148
+ image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
149
+
150
+
151
+ def expand_inputs_for_generation(
152
+ input_ids,
153
+ expand_size=1,
154
+ is_encoder_decoder=False,
155
+ attention_mask=None,
156
+ encoder_outputs=None,
157
+ **model_kwargs,
158
+ ):
159
+ expanded_return_idx = (
160
+ torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
161
+ )
162
+ input_ids = input_ids.index_select(0, expanded_return_idx)
163
+ model_kwargs["pixel_values"] = model_kwargs.get("pixel_values", None)
164
+ model_kwargs["image_hidden_states"] = model_kwargs.get("image_hidden_states", None)
165
+
166
+ if "token_type_ids" in model_kwargs:
167
+ token_type_ids = model_kwargs["token_type_ids"]
168
+ model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)
169
+
170
+ if attention_mask is not None:
171
+ model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)
172
+
173
+ if model_kwargs["pixel_values"] is not None:
174
+ model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select(0, expanded_return_idx)
175
+
176
+ elif model_kwargs["image_hidden_states"] is not None:
177
+ model_kwargs["image_hidden_states"] = model_kwargs["image_hidden_states"].index_select(0, expanded_return_idx)
178
+
179
+ return input_ids, model_kwargs
180
+
181
+
182
+ def update_model_kwargs_for_generation(outputs, model_kwargs):
183
+ # must have this key set to at least None
184
+ if "past_key_values" in outputs:
185
+ model_kwargs["past_key_values"] = outputs.past_key_values
186
+ else:
187
+ model_kwargs["past_key_values"] = None
188
+
189
+ # update token_type_ids with last value
190
+ if "token_type_ids" in model_kwargs:
191
+ token_type_ids = model_kwargs["token_type_ids"]
192
+ model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
193
+
194
+ # update attention masks
195
+ if "attention_mask" in model_kwargs:
196
+ attention_mask = model_kwargs["attention_mask"]
197
+ model_kwargs["attention_mask"] = torch.cat(
198
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
199
+ )
200
+
201
+ # Get the precomputed image_hidden_states
202
+ model_kwargs["image_hidden_states"] = outputs.image_hidden_states
203
+
204
+ return model_kwargs
205
+
206
+
207
+ def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
208
+ token_type_ids = kwargs.get("token_type_ids", None)
209
+ # only last token for inputs_ids if past is defined in kwargs
210
+ if past_key_values:
211
+ input_ids = input_ids[:, -1].unsqueeze(-1)
212
+ if token_type_ids is not None:
213
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
214
+
215
+ attention_mask = kwargs.get("attention_mask", None)
216
+ position_ids = kwargs.get("position_ids", None)
217
+
218
+ if attention_mask is not None and position_ids is None:
219
+ # create position_ids on the fly for batch generation
220
+ position_ids = attention_mask.long().cumsum(-1) - 1
221
+ position_ids.masked_fill_(attention_mask == 0, 1)
222
+ if past_key_values:
223
+ position_ids = position_ids[:, -1].unsqueeze(-1)
224
+
225
+ pixel_values = kwargs.get("pixel_values", None)
226
+ image_hidden_states = kwargs.get("image_hidden_states", None)
227
+
228
+ return {
229
+ "input_ids": input_ids,
230
+ "past_key_values": past_key_values,
231
+ "use_cache": kwargs.get("use_cache"),
232
+ "position_ids": position_ids,
233
+ "attention_mask": attention_mask,
234
+ "token_type_ids": token_type_ids,
235
+ "pixel_values": pixel_values,
236
+ "image_hidden_states": image_hidden_states,
237
+ }
238
+
239
+
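A quick worked example of the `position_ids` computation used in `prepare_inputs_for_generation` above (toy left-padded mask, purely illustrative and not part of the uploaded file):

import torch

mask = torch.tensor([[0, 0, 1, 1, 1]])       # left-padded attention mask
position_ids = mask.long().cumsum(-1) - 1    # tensor([[-1, -1, 0, 1, 2]])
position_ids.masked_fill_(mask == 0, 1)      # tensor([[1, 1, 0, 1, 2]]) -- padded slots get a dummy position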
240
+ def freeze_model(model, module_exceptions=[]):
241
+ mapping = {
242
+ "LayerNorm": nn.LayerNorm,
243
+ "Linear": nn.Linear,
244
+ "Embedding": nn.Embedding,
245
+ }
246
+ module_exceptions_mapped = [mapping[m] for m in module_exceptions]
247
+ for module in model.modules():
248
+ if module_exceptions and any([isinstance(module, t) for t in module_exceptions_mapped]):
249
+ module.requires_grad_(True) # Explicitly setting it to true to avoid any mistakes
250
+ else:
251
+ module.requires_grad_(False)
252
+ return model
253
+
254
+
255
+ class DecoupledEmbedding(nn.Embedding):
256
+ # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
257
+ """
258
+ Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings.
259
+ In practice, the regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0, then it will create `num_additional_embeddings` additional parameters that are always trained.
260
+ If `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
261
+ """
262
+
263
+ def __init__(
264
+ self,
265
+ num_embeddings,
266
+ num_additional_embeddings,
267
+ embedding_dim,
268
+ partially_freeze=False,
269
+ device=None,
270
+ dtype=None,
271
+ padding_idx=None,
272
+ **kwargs,
273
+ ) -> None:
274
+ """
275
+ num_additional_embeddings: int. Number of additional embeddings. Only useful when `partially_freeze=True`.
276
+ partially_freeze: bool. If True, the regular `weight` will be frozen. `additional_embedding` is never frozen.
277
+
278
+ Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`, `max_norm` or `norm_type`. We are not supporting these.
279
+ """
280
+ if padding_idx is not None and padding_idx > num_embeddings:
281
+ raise ValueError(f"padding_idx must be within num_embeddings. Got {padding_idx} and {num_embeddings}")
282
+ super().__init__(
283
+ num_embeddings=num_embeddings,
284
+ embedding_dim=embedding_dim,
285
+ device=device,
286
+ dtype=dtype,
287
+ padding_idx=padding_idx,
288
+ **kwargs,
289
+ )
290
+ self.num_embeddings = num_embeddings
291
+ self.padding_idx = padding_idx
292
+ self.num_additional_embeddings = num_additional_embeddings
293
+ self.partially_freeze = partially_freeze
294
+
295
+ if partially_freeze:
296
+ self.weight.requires_grad_(False)
297
+
298
+ if self.num_additional_embeddings > 0:
299
+ self.additional_embedding = nn.Embedding(
300
+ num_embeddings=self.num_additional_embeddings,
301
+ embedding_dim=embedding_dim,
302
+ device=device,
303
+ dtype=dtype,
304
+ )
305
+
306
+ def forward(self, input_ids):
307
+ """
308
+ we have 2 embeddings, with different indices - one pretrained self.weight and another
309
+ self.additional_embedding.weight that is being trained.
310
+
311
+ in order to make a lookup of the input ids, we:
312
+ 1. find out the indices of the entries belonging to the 2nd embedding
313
+ 2. extract those values while subtracting the size of the first embedding (num_embeddings),
314
+ since the 2nd embedding starts from 0 and not num_embeddings
315
+ 3. perform the 2nd embedding lookup
316
+ 4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index
317
+ 5. perform the 1st embedding lookup
318
+ 6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup
319
+
320
+ note: for the 1st embedding lookup we could have looked up only the low indices and not do
321
+ the padding, but then we have to create a new tensor and populate it with 2 tensors that are
322
+ spread out across various indices - i.e. not a simple concat - I haven't benchmarked the
323
+ complex case if it's any faster, given that seqlens are usually relatively short it's
324
+ probably not faster or if faster not by much - but might be a good idea to measure.
325
+
326
+ """
327
+ if self.num_additional_embeddings == 0:
328
+ return F.embedding(input_ids, self.weight)  # no additional vocabulary: plain lookup in the regular (possibly frozen) table
329
+
330
+ # Clone so that we don't modify the original input_ids later on
331
+ input_ids = input_ids.clone()
332
+ additional_vocab_indices = torch.where(input_ids >= self.num_embeddings)
333
+ input_ids_additional_vocab = input_ids[additional_vocab_indices]
334
+ additional_embeddings = self.additional_embedding(input_ids_additional_vocab - self.num_embeddings)
335
+
336
+ # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
337
+ input_ids[additional_vocab_indices] = 0
338
+ full_vector = F.embedding(input_ids, self.weight)
339
+
340
+ # overwrite the records with high indices
341
+ full_vector[additional_vocab_indices] = additional_embeddings
342
+
343
+ return full_vector
344
+
345
+ def extra_repr(self) -> str:
346
+ return "num_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
347
+ self.num_embeddings,
348
+ self.num_additional_embeddings,
349
+ self.embedding_dim,
350
+ self.partially_freeze,
351
+ )
352
+
353
+
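A minimal usage sketch of `DecoupledEmbedding` (the sizes below are hypothetical, chosen only for illustration): ids below `num_embeddings` are looked up in the regular, possibly frozen, table, while ids at or above it are routed to the always-trainable `additional_embedding`.

import torch

emb = DecoupledEmbedding(num_embeddings=8, num_additional_embeddings=2, embedding_dim=4, partially_freeze=True)
ids = torch.tensor([[1, 7, 8, 9]])   # ids 8 and 9 fall in the additional vocabulary
out = emb(ids)                       # rows for 8 and 9 come from emb.additional_embedding
assert out.shape == (1, 4, 4)
assert not emb.weight.requires_grad and emb.additional_embedding.weight.requires_grad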
354
+ class DecoupledLinear(nn.Linear):
355
+ # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
356
+ """
357
+ Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters.
358
+ In practice, the regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `out_additional_features` > 0, then it will create `out_additional_features * in_features` additional parameters that are always trained.
359
+ If `out_additional_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
360
+ """
361
+
362
+ def __init__(
363
+ self,
364
+ in_features: int,
365
+ out_features: int,
366
+ out_additional_features: int = 0,
367
+ bias: bool = True,
368
+ partially_freeze: bool = True,
369
+ device=None,
370
+ dtype=None,
371
+ ) -> None:
372
+ """
373
+ out_additional_features: int. Number of additional trainable dimensions. Only makes sense when `partially_freeze=True`.
374
+ partially_freeze: bool. If True, the regular `weight` will be frozen and extra parameters (if any) will be trainable. If False, default to the regular behavior of nn.Linear.
375
+ """
376
+ super().__init__(in_features, out_features, bias, device, dtype)
377
+ self.out_additional_features = out_additional_features
378
+ self.partially_freeze = partially_freeze
379
+
380
+ self.in_features = in_features
381
+ self.out_features = out_features
382
+
383
+ if partially_freeze:
384
+ self.weight.requires_grad_(False)
385
+ if bias:
386
+ self.bias.requires_grad_(False)
387
+
388
+ if out_additional_features > 0:
389
+ self.additional_fc = nn.Linear(
390
+ in_features=in_features,
391
+ out_features=out_additional_features,
392
+ bias=bias,
393
+ device=device,
394
+ dtype=dtype,
395
+ )
396
+
397
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
398
+ output = F.linear(input, self.weight, self.bias)
399
+
400
+ if self.out_additional_features > 0:
401
+ additional_features = self.additional_fc(input)
402
+ output = torch.cat((output, additional_features), -1)
403
+
404
+ return output
405
+
406
+ def extra_repr(self) -> str:
407
+ """Overwriting `nn.Linear.extra_repr` to include new parameters."""
408
+ return "in_features={}, out_features={}, out_additional_features={}, bias={}, partially_freeze={}".format(
409
+ self.in_features,
410
+ self.out_features,
411
+ self.out_additional_features,
412
+ self.bias is not None,
413
+ self.partially_freeze,
414
+ )
415
+
416
+
417
+ class SwiGLU(nn.Module):
418
+ def __init__(self, embed_dim) -> None:
419
+ super().__init__()
420
+ self.fc1 = nn.Linear(embed_dim, embed_dim, bias=False)
421
+ self.fc2 = nn.Linear(embed_dim, embed_dim, bias=False)
422
+
423
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
424
+ x_1 = self.fc1(x)
425
+ x_1 = torch.mul(x_1, torch.sigmoid(x_1))
426
+ x_2 = self.fc2(x)
427
+ x = torch.mul(x_1, x_2)
428
+ return x
429
+
430
+
431
+ class ModalityProjection(nn.Module):
432
+ def __init__(self, embed_dim_in, embed_dim_out) -> None:
433
+ super().__init__()
434
+ self.fc1 = nn.Linear(embed_dim_in, embed_dim_out, bias=False)
435
+ self.act = SwiGLU(embed_dim_out)
436
+ self.fc2 = nn.Linear(embed_dim_out, embed_dim_out, bias=False)
437
+
438
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
439
+ x = self.fc1(x)
440
+ x = self.act(x)
441
+ x = self.fc2(x)
442
+ return x
443
+
444
+
445
+ class PerceiverResampler(nn.Module):
446
+ def __init__(
447
+ self, embed_dim: int, depth: int, n_heads: int, head_dim: int, n_latents: int, qk_layer_norms: bool
448
+ ) -> None:
449
+ """
450
+ Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet or ViT or
451
+ MAE) of a given dimension, performs `depth` blocks of cross-attention with a fixed `n_latents` inputs, then
452
+ returns a Tensor of shape [bsz, n_latents, embed_dim].
453
+ :param embed_dim: Dimensionality of embeddings being fed to the Perceiver Resampler (also dimensionality of
454
+ latent embeddings *returned* by the Perceiver Resampler). Could be e.g., ViT embed_dim, ResNet
455
+ pool dim, and so on.
456
+ :param depth: Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3).
457
+ :param n_heads: Number of heads in each Transformer block (for multi-headed self-attention).
458
+ :param head_dim: Dimensionality of each head projection in the Transformer block.
459
+ :param n_latents: Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
460
+ """
461
+ super().__init__()
462
+ self.embed_dim, self.n_heads, self.head_dim, self.n_latents = embed_dim, n_heads, head_dim, n_latents
463
+ self.qk_layer_norms = qk_layer_norms
464
+
465
+ # Create Latents for Perceiver
466
+ self.latents = nn.Parameter(torch.ones(self.n_latents, self.embed_dim))
467
+
468
+ self.intermediate_dim = self.embed_dim * 4
469
+ # Create Transformer Blocks
470
+ self.blocks = nn.ModuleList(
471
+ [
472
+ nn.ModuleList(
473
+ [
474
+ PerceiverAttention(self.embed_dim, self.n_heads, self.head_dim, self.qk_layer_norms),
475
+ MLP(self.embed_dim, self.intermediate_dim),
476
+ ]
477
+ )
478
+ for _ in range(depth)
479
+ ]
480
+ )
481
+ self.layer_norm = nn.LayerNorm(self.embed_dim)
482
+
483
+ def forward(self, context: torch.Tensor) -> torch.Tensor:
484
+ """Resample arbitrary length context & *compress* down to self.n_latents latent embeddings"""
485
+ latents = repeat(self.latents, "seq embed -> bsz seq embed", bsz=context.shape[0])
486
+
487
+ # Feed through Perceiver Attention blocks...
488
+ for attn, ff in self.blocks:
489
+ latents = attn(context, latents) + latents
490
+ latents = ff(latents) + latents
491
+
492
+ return self.layer_norm(latents)
493
+
494
+
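A shape-only sketch of the resampler defined above (all sizes are made up for illustration; they do not come from this checkpoint's config):

import torch

resampler = PerceiverResampler(embed_dim=64, depth=2, n_heads=4, head_dim=16, n_latents=8, qk_layer_norms=True)
vision_tokens = torch.randn(2, 100, 64)   # (batch, seq, embed_dim) from a vision encoder
latents = resampler(vision_tokens)
assert latents.shape == (2, 8, 64)        # arbitrary-length sequence compressed to n_latents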
495
+ class PerceiverAttention(nn.Module):
496
+ def __init__(self, embed_dim: int, n_heads: int, head_dim: int, qk_layer_norms: bool) -> None:
497
+ """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
498
+ super().__init__()
499
+ self.embed_dim, self.n_heads, self.head_dim = embed_dim, n_heads, head_dim
500
+ self.qk_layer_norms = qk_layer_norms
501
+ # Normalization & Scaling
502
+ self.context_layer_norm = nn.LayerNorm(self.embed_dim)
503
+ self.latents_layer_norm = nn.LayerNorm(self.embed_dim)
504
+ if self.qk_layer_norms:
505
+ self.q_layer_norm = nn.LayerNorm(self.head_dim)
506
+ self.k_layer_norm = nn.LayerNorm(self.head_dim)
507
+
508
+ self.qk_scale = self.head_dim**-0.5
509
+
510
+ # Q, K, V Projection (no bias -- detail from Perceiver/Flamingo Papers).
511
+ self.q_proj = nn.Linear(self.embed_dim, self.n_heads * self.head_dim, bias=False)
512
+ self.k_proj = nn.Linear(self.embed_dim, self.n_heads * self.head_dim, bias=False)
513
+ self.v_proj = nn.Linear(self.embed_dim, self.n_heads * self.head_dim, bias=False)
514
+
515
+ self.output_proj = nn.Linear(self.n_heads * self.head_dim, self.embed_dim, bias=False)
516
+
517
+ def forward(self, context: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
518
+ """
519
+ Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension!
520
+ :param context: Tensor of shape [bsz, seq, embed_dim] representing long-form context to resample.
521
+ :param latents: Tensor of shape [bsz, n_latents, embed_dim] representing fixed length latents to compress to.
522
+ :return: Tensor of shape [bsz, n_latents, embed_dim] representing attention over latents w/ cross from context.
523
+ """
524
+ context = self.context_layer_norm(context)
525
+ latents = self.latents_layer_norm(latents)
526
+
527
+ # Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn!
528
+ # Note: This results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents`
529
+ q = self.q_proj(latents)
530
+ k = self.k_proj(torch.cat([context, latents], dim=-2))
531
+ v = self.v_proj(torch.cat([context, latents], dim=-2))
532
+
533
+ # Multiheaded Self-Attention w/ stable softmax (subtract per-row max -- `amax` -- before softmax call)
534
+ # =>> `attn` should be a 2D matrix of shape [n_latents x (context + n_latents)]
535
+ q, k, v = [rearrange(x, "bsz seq (heads embed) -> bsz heads seq embed", heads=self.n_heads) for x in (q, k, v)]
536
+ if self.qk_layer_norms:
537
+ q = self.q_layer_norm(q)
538
+ k = self.k_layer_norm(k)
539
+
540
+ scores = torch.einsum("... i d, ... j d -> ... i j", q * self.qk_scale, k)
541
+ stabilized_scores = scores - (scores.amax(dim=-1, keepdim=True).detach())
542
+ attn = stabilized_scores.softmax(dim=-1)
543
+
544
+ # Attend & project back to output...
545
+ resampled = torch.einsum("... i j, ... j d -> ... i d", attn, v)
546
+ return self.output_proj(
547
+ rearrange(resampled, "bsz heads seq embed -> bsz seq (heads embed)", heads=self.n_heads)
548
+ )
549
+
550
+
551
+ class MLP(nn.Module):
552
+ def __init__(self, embed_dim, intermediate_size):
553
+ """Simple MLP block with intermediate_size and embedding size"""
554
+ super().__init__()
555
+ self.embed_dim = embed_dim
556
+ self.ln = nn.LayerNorm(self.embed_dim)
557
+ self.fc = nn.Linear(self.embed_dim, intermediate_size, bias=False)
558
+ self.act = nn.ReLU()
559
+ self.c_proj = nn.Linear(intermediate_size, self.embed_dim, bias=False)
560
+
561
+ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
562
+ hidden_states = self.ln(hidden_states)
563
+ hidden_states = self.fc(hidden_states)
564
+ hidden_states = self.act(hidden_states)
565
+ hidden_states = self.c_proj(hidden_states)
566
+
567
+ return hidden_states
568
+
569
+
570
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
571
+ def _get_unpad_data(attention_mask):
572
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
573
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
574
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
575
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
576
+ return (
577
+ indices,
578
+ cu_seqlens,
579
+ max_seqlen_in_batch,
580
+ )
581
+
582
+
583
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral
584
+ class MistralRMSNorm(nn.Module):
585
+ def __init__(self, hidden_size, eps=1e-6):
586
+ """
587
+ MistralRMSNorm is equivalent to T5LayerNorm
588
+ """
589
+ super().__init__()
590
+ self.weight = nn.Parameter(torch.ones(hidden_size))
591
+ self.variance_epsilon = eps
592
+
593
+ def forward(self, hidden_states):
594
+ input_dtype = hidden_states.dtype
595
+ hidden_states = hidden_states.to(torch.float32)
596
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
597
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
598
+ return self.weight * hidden_states.to(input_dtype)
599
+
600
+
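For reference, the normalization above reduces to the following one-liner on a toy float32 tensor (an illustrative check, not part of the uploaded file):

import torch

x = torch.randn(2, 3, 8)
norm = MistralRMSNorm(hidden_size=8, eps=1e-6)
reference = norm.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
assert torch.allclose(norm(x), reference, atol=1e-5)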
601
+ # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
602
+ class MistralRotaryEmbedding(nn.Module):
603
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
604
+ super().__init__()
605
+
606
+ self.dim = dim
607
+ self.max_position_embeddings = max_position_embeddings
608
+ self.base = base
609
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
610
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
611
+
612
+ # Build here to make `torch.jit.trace` work.
613
+ self._set_cos_sin_cache(
614
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
615
+ )
616
+
617
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
618
+ self.max_seq_len_cached = seq_len
619
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
620
+
621
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
622
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
623
+ emb = torch.cat((freqs, freqs), dim=-1)
624
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
625
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
626
+
627
+ def forward(self, x, seq_len=None):
628
+ # x: [bs, num_attention_heads, seq_len, head_size]
629
+ if seq_len > self.max_seq_len_cached:
630
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
631
+
632
+ return (
633
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
634
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
635
+ )
636
+
637
+
638
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
639
+ def rotate_half(x):
640
+ """Rotates half the hidden dims of the input."""
641
+ x1 = x[..., : x.shape[-1] // 2]
642
+ x2 = x[..., x.shape[-1] // 2 :]
643
+ return torch.cat((-x2, x1), dim=-1)
644
+
645
+
646
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
647
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
648
+ cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
649
+ sin = sin[position_ids].unsqueeze(1)
650
+ q_embed = (q * cos) + (rotate_half(q) * sin)
651
+ k_embed = (k * cos) + (rotate_half(k) * sin)
652
+ return q_embed, k_embed
653
+
654
+
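A small shape walk-through for the rotary-embedding helpers above (toy sizes, illustrative only):

import torch

rope = MistralRotaryEmbedding(dim=8, max_position_embeddings=32)
q = torch.randn(1, 2, 5, 8)                  # (batch, heads, seq, head_dim)
k = torch.randn(1, 2, 5, 8)
cos, sin = rope(q, seq_len=5)                # each of shape (5, 8)
position_ids = torch.arange(5).unsqueeze(0)  # (1, 5)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert q_rot.shape == q.shape and k_rot.shape == k.shape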
655
+ class MistralMLP(nn.Module):
656
+ def __init__(self, config):
657
+ super().__init__()
658
+ self.config = config
659
+ self.hidden_size = config.hidden_size
660
+ self.intermediate_size = config.intermediate_size
661
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
662
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
663
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
664
+ self.act_fn = ACT2FN[config.hidden_act]
665
+
666
+ def forward(self, x):
667
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
668
+
669
+
670
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
671
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
672
+ """
673
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
674
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
675
+ """
676
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
677
+ if n_rep == 1:
678
+ return hidden_states
679
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
680
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
681
+
682
+
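And a toy illustration of the grouped-query expansion performed by `repeat_kv` (hypothetical sizes): 2 key/value heads are repeated to serve 8 query heads.

import torch

kv = torch.randn(1, 2, 10, 64)      # (batch, num_key_value_heads, seq, head_dim)
expanded = repeat_kv(kv, n_rep=4)   # n_rep = num_attention_heads // num_key_value_heads
assert expanded.shape == (1, 8, 10, 64)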
683
+ class MistralAttention(nn.Module):
684
+ """
685
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
686
+ and "Generating Long Sequences with Sparse Transformers".
687
+ """
688
+
689
+ def __init__(self, config: VMistralConfig, qk_layer_norms: bool = False):
690
+ super().__init__()
691
+ self.config = config
692
+ self.hidden_size = config.hidden_size
693
+ self.num_heads = config.num_attention_heads
694
+ self.head_dim = self.hidden_size // self.num_heads
695
+ self.num_key_value_heads = config.num_key_value_heads
696
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
697
+ self.max_position_embeddings = config.max_position_embeddings
698
+ self.rope_theta = config.rope_theta
699
+ self.is_causal = True
700
+
701
+ if (self.head_dim * self.num_heads) != self.hidden_size:
702
+ raise ValueError(
703
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
704
+ f" and `num_heads`: {self.num_heads})."
705
+ )
706
+
707
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
708
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
709
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
710
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
711
+
712
+ self.qk_layer_norms = qk_layer_norms
713
+ if self.qk_layer_norms:
714
+ self.q_layer_norm = MistralRMSNorm(self.head_dim, eps=config.rms_norm_eps)
715
+ self.k_layer_norm = MistralRMSNorm(self.head_dim, eps=config.rms_norm_eps)
716
+
717
+ self.rotary_emb = MistralRotaryEmbedding(
718
+ self.head_dim,
719
+ max_position_embeddings=self.max_position_embeddings,
720
+ base=self.rope_theta,
721
+ )
722
+ self.attention_dropout = config.attention_dropout
723
+
724
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
725
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
726
+
727
+ def forward(
728
+ self,
729
+ hidden_states: torch.Tensor,
730
+ key_value_states: Optional[torch.Tensor] = None,
731
+ attention_mask: Optional[torch.Tensor] = None,
732
+ position_ids: Optional[torch.LongTensor] = None,
733
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
734
+ output_attentions: bool = False,
735
+ use_cache: bool = False,
736
+ **kwargs,
737
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
738
+ if "padding_mask" in kwargs:
739
+ warnings.warn(
740
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use"
741
+ " `attention_mask` instead.`"
742
+ )
743
+
744
+ bsz, q_len, _ = hidden_states.size()
745
+
746
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
747
+ key_states = (
748
+ self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
749
+ )
750
+ value_states = (
751
+ self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
752
+ )
753
+
754
+ kv_seq_len = key_states.shape[-2]
755
+ if past_key_value is not None:
756
+ kv_seq_len += past_key_value[0].shape[-2]
757
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
758
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
759
+
760
+ if past_key_value is not None:
761
+ # reuse k, v, self_attention
762
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
763
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
764
+
765
+ past_key_value = (key_states, value_states) if use_cache else None
766
+
767
+ if self.qk_layer_norms:
768
+ query_states = self.q_layer_norm(query_states)
769
+ key_states = self.k_layer_norm(key_states)
770
+
771
+ # repeat k/v heads if n_kv_heads < n_heads
772
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
773
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
774
+
775
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
776
+
777
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
778
+ raise ValueError(
779
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
780
+ f" {attn_weights.size()}"
781
+ )
782
+
783
+ if attention_mask is not None:
784
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
785
+ raise ValueError(
786
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
787
+ )
788
+
789
+ attn_weights = attn_weights + attention_mask
790
+
791
+ # upcast attention to fp32
792
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
793
+ attn_output = torch.matmul(attn_weights, value_states)
794
+
795
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
796
+ raise ValueError(
797
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
798
+ f" {attn_output.size()}"
799
+ )
800
+
801
+ attn_output = attn_output.transpose(1, 2).contiguous()
802
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
803
+
804
+ attn_output = self.o_proj(attn_output)
805
+
806
+ if not output_attentions:
807
+ attn_weights = None
808
+
809
+ return attn_output, attn_weights, past_key_value
810
+
811
+
812
+ class MistralFlashAttention2(MistralAttention):
813
+ """
814
+ Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stay
815
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
816
+ flash attention and deal with padding tokens in case the input contains any of them.
817
+ """
818
+
819
+ def forward(
820
+ self,
821
+ hidden_states: torch.Tensor,
822
+ attention_mask: Optional[torch.Tensor] = None,
823
+ position_ids: Optional[torch.LongTensor] = None,
824
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
825
+ output_attentions: bool = False,
826
+ use_cache: bool = False,
827
+ **kwargs,
828
+ ):
829
+ if "padding_mask" in kwargs:
830
+ warnings.warn(
831
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use"
832
+ " `attention_mask` instead.`"
833
+ )
834
+
835
+ # overwrite attention_mask with padding_mask
836
+ attention_mask = kwargs.pop("padding_mask")
837
+ bsz, q_len, _ = hidden_states.size()
838
+
839
+ query_states = self.q_proj(hidden_states)
840
+ key_states = self.k_proj(hidden_states)
841
+ value_states = self.v_proj(hidden_states)
842
+
843
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
844
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
845
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
846
+
847
+ kv_seq_len = key_states.shape[-2]
848
+ if past_key_value is not None:
849
+ kv_seq_len += past_key_value[0].shape[-2]
850
+
851
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
852
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
853
+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
854
+
855
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
856
+
857
+ use_sliding_windows = False
858
+ # use_sliding_windows = (
859
+ # _flash_supports_window_size
860
+ # and hasattr(self.config, "sliding_window") is not None
861
+ # and kv_seq_len > self.config.sliding_window
862
+ # )
863
+ _flash_supports_window_size = None
864
+
865
+ if not _flash_supports_window_size:
866
+ logger.warning_once(
867
+ "The current flash attention version does not support sliding window attention, for a more memory"
868
+ " efficient implementation make sure to upgrade flash-attn library."
869
+ )
870
+
871
+ if past_key_value is not None:
872
+ # Activate cache slicing only if the config has a `sliding_window` attribute
873
+ if hasattr(self.config, "sliding_window") and kv_seq_len > self.config.sliding_window:
874
+ slicing_tokens = kv_seq_len - self.config.sliding_window
875
+
876
+ past_key = past_key_value[0]
877
+ past_value = past_key_value[1]
878
+
879
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
880
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
881
+
882
+ if past_key.shape[-2] != self.config.sliding_window - 1:
883
+ raise ValueError(
884
+ "past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1,"
885
+ f" head_dim`), got {past_key.shape}"
886
+ )
887
+
888
+ past_key_value = (past_key, past_value)
889
+
890
+ if attention_mask is not None:
891
+ attention_mask = attention_mask[:, slicing_tokens:]
892
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
893
+
894
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
895
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
896
+
897
+ past_key_value = (key_states, value_states) if use_cache else None
898
+
899
+ # repeat k/v heads if n_kv_heads < n_heads
900
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
901
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
902
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
903
+
904
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
905
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
906
+ # cast them back to the expected dtype just to be sure everything works as expected.
907
+ input_dtype = query_states.dtype
908
+ if input_dtype == torch.float32:
909
+ # Handle the case where the model is quantized
910
+ if hasattr(self.config, "_pre_quantization_dtype"):
911
+ target_dtype = self.config._pre_quantization_dtype
912
+ else:
913
+ target_dtype = self.q_proj.weight.dtype
914
+
915
+ logger.warning_once(
916
+ "The input hidden states seems to be silently casted in float32, this might be related to the fact"
917
+ " you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
918
+ f" {target_dtype}."
919
+ )
920
+
921
+ query_states = query_states.to(target_dtype)
922
+ key_states = key_states.to(target_dtype)
923
+ value_states = value_states.to(target_dtype)
924
+
925
+ # Reshape to the expected shape for Flash Attention
926
+ query_states = query_states.transpose(1, 2)
927
+ key_states = key_states.transpose(1, 2)
928
+ value_states = value_states.transpose(1, 2)
929
+
930
+ attn_output = self._flash_attention_forward(
931
+ query_states,
932
+ key_states,
933
+ value_states,
934
+ attention_mask,
935
+ q_len,
936
+ dropout=dropout_rate,
937
+ use_sliding_windows=use_sliding_windows,
938
+ )
939
+
940
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
941
+ attn_output = self.o_proj(attn_output)
942
+
943
+ if not output_attentions:
944
+ attn_weights = None
945
+
946
+ return attn_output, attn_weights, past_key_value
947
+
948
+ def _flash_attention_forward(
949
+ self,
950
+ query_states,
951
+ key_states,
952
+ value_states,
953
+ attention_mask,
954
+ query_length,
955
+ dropout=0.0,
956
+ softmax_scale=None,
957
+ use_sliding_windows=False,
958
+ ):
959
+ """
960
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
961
+ first unpad the input, then compute the attention scores and pad the final attention scores back.
962
+
963
+ Args:
964
+ query_states (`torch.Tensor`):
965
+ Input query states to be passed to Flash Attention API
966
+ key_states (`torch.Tensor`):
967
+ Input key states to be passed to Flash Attention API
968
+ value_states (`torch.Tensor`):
969
+ Input value states to be passed to Flash Attention API
970
+ attention_mask (`torch.Tensor`):
971
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
972
+ position of padding tokens and 1 for the position of non-padding tokens.
973
+ dropout (`float`, *optional*):
974
+ Attention dropout
975
+ softmax_scale (`float`, *optional*):
976
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
977
+ use_sliding_windows (`bool`, *optional*):
978
+ Whether to activate sliding window attention.
979
+ """
980
+ # Contains at least one padding token in the sequence
981
+ if attention_mask is not None:
982
+ batch_size = query_states.shape[0]
983
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
984
+ query_states, key_states, value_states, attention_mask, query_length
985
+ )
986
+
987
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
988
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
989
+
990
+ if not use_sliding_windows:
991
+ attn_output_unpad = flash_attn_varlen_func(
992
+ query_states,
993
+ key_states,
994
+ value_states,
995
+ cu_seqlens_q=cu_seqlens_q,
996
+ cu_seqlens_k=cu_seqlens_k,
997
+ max_seqlen_q=max_seqlen_in_batch_q,
998
+ max_seqlen_k=max_seqlen_in_batch_k,
999
+ dropout_p=dropout,
1000
+ softmax_scale=softmax_scale,
1001
+ causal=self.is_causal,
1002
+ )
1003
+ else:
1004
+ attn_output_unpad = flash_attn_varlen_func(
1005
+ query_states,
1006
+ key_states,
1007
+ value_states,
1008
+ cu_seqlens_q=cu_seqlens_q,
1009
+ cu_seqlens_k=cu_seqlens_k,
1010
+ max_seqlen_q=max_seqlen_in_batch_q,
1011
+ max_seqlen_k=max_seqlen_in_batch_k,
1012
+ dropout_p=dropout,
1013
+ softmax_scale=softmax_scale,
1014
+ causal=self.is_causal,
1015
+ window_size=(self.config.sliding_window, self.config.sliding_window),
1016
+ )
1017
+
1018
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
1019
+ else:
1020
+ if not use_sliding_windows:
1021
+ attn_output = flash_attn_func(
1022
+ query_states,
1023
+ key_states,
1024
+ value_states,
1025
+ dropout,
1026
+ softmax_scale=softmax_scale,
1027
+ causal=self.is_causal,
1028
+ )
1029
+ else:
1030
+ attn_output = flash_attn_func(
1031
+ query_states,
1032
+ key_states,
1033
+ value_states,
1034
+ dropout,
1035
+ softmax_scale=softmax_scale,
1036
+ causal=self.is_causal,
1037
+ window_size=(self.config.sliding_window, self.config.sliding_window),
1038
+ )
1039
+
1040
+ return attn_output
1041
+
1042
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
1043
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
1044
+
1045
+ # On the first iteration we need to properly re-create the padding mask
1046
+ # by slicing it on the proper place
1047
+ if kv_seq_len != attention_mask.shape[-1]:
1048
+ attention_mask_num_tokens = attention_mask.shape[-1]
1049
+ attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
1050
+
1051
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
1052
+
1053
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
1054
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
1055
+
1056
+ if query_length == kv_seq_len:
1057
+ query_layer = index_first_axis(
1058
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
1059
+ )
1060
+ cu_seqlens_q = cu_seqlens_k
1061
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
1062
+ indices_q = indices_k
1063
+ elif query_length == 1:
1064
+ max_seqlen_in_batch_q = 1
1065
+ cu_seqlens_q = torch.arange(
1066
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
1067
+ ) # There is a memcpy here, that is very bad.
1068
+ indices_q = cu_seqlens_q[:-1]
1069
+ query_layer = query_layer.squeeze(1)
1070
+ else:
1071
+ # The -q_len: slice assumes left padding.
1072
+ attention_mask = attention_mask[:, -query_length:]
1073
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
1074
+
1075
+ return (
1076
+ query_layer,
1077
+ key_layer,
1078
+ value_layer,
1079
+ indices_q,
1080
+ (cu_seqlens_q, cu_seqlens_k),
1081
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
1082
+ )
1083
+
1084
+
1085
+ class MistralDecoderLayer(nn.Module):
1086
+ def __init__(self, config: VMistralConfig):
1087
+ super().__init__()
1088
+ self.hidden_size = config.hidden_size
1089
+ self.self_attn = (
1090
+ MistralAttention(config=config)
1091
+ )
1092
+ self.mlp = MistralMLP(config)
1093
+ self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1094
+ self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1095
+
1096
+ def forward(
1097
+ self,
1098
+ hidden_states: torch.Tensor,
1099
+ attention_mask: Optional[torch.Tensor] = None,
1100
+ position_ids: Optional[torch.LongTensor] = None,
1101
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
1102
+ output_attentions: Optional[bool] = False,
1103
+ use_cache: Optional[bool] = False,
1104
+ **kwargs,
1105
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1106
+ if "padding_mask" in kwargs:
1107
+ warnings.warn(
1108
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use"
1109
+ " `attention_mask` instead.`"
1110
+ )
1111
+ """
1112
+ Args:
1113
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
1114
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
1115
+ `(batch, sequence_length)` where padding elements are indicated by 0.
1116
+ output_attentions (`bool`, *optional*):
1117
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1118
+ returned tensors for more detail.
1119
+ use_cache (`bool`, *optional*):
1120
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1121
+ (see `past_key_values`).
1122
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
1123
+ """
1124
+
1125
+ residual = hidden_states
1126
+
1127
+ hidden_states = self.input_layernorm(hidden_states)
1128
+
1129
+ # Self Attention
1130
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
1131
+ hidden_states=hidden_states,
1132
+ attention_mask=attention_mask,
1133
+ position_ids=position_ids,
1134
+ past_key_value=past_key_value,
1135
+ output_attentions=output_attentions,
1136
+ use_cache=use_cache,
1137
+ )
1138
+ hidden_states = residual + hidden_states
1139
+
1140
+ # Fully Connected
1141
+ residual = hidden_states
1142
+ hidden_states = self.post_attention_layernorm(hidden_states)
1143
+ hidden_states = self.mlp(hidden_states)
1144
+ hidden_states = residual + hidden_states
1145
+
1146
+ outputs = (hidden_states,)
1147
+
1148
+ if output_attentions:
1149
+ outputs += (self_attn_weights,)
1150
+
1151
+ if use_cache:
1152
+ outputs += (present_key_value,)
1153
+
1154
+ return outputs
1155
+
1156
+
1157
+ MISTRAL_START_DOCSTRING = r"""
1158
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1159
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
1160
+ etc.)
1161
+
1162
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1163
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
1164
+ and behavior.
1165
+
1166
+ Parameters:
1167
+ config ([`VMistralConfig`]):
1168
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
1169
+ load the weights associated with the model, only the configuration. Check out the
1170
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1171
+ """
1172
+
1173
+
1174
+ @add_start_docstrings(
1175
+ "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
1176
+ MISTRAL_START_DOCSTRING,
1177
+ )
1178
+ class VMistralPreTrainedModel(PreTrainedModel):
1179
+ config_class = VMistralConfig
1180
+ base_model_prefix = "model"
1181
+ supports_gradient_checkpointing = True
1182
+ _no_split_modules = ["MistralDecoderLayer"]
1183
+ _skip_keys_device_placement = "past_key_values"
1184
+ _supports_sdpa = False
1185
+
1186
+ def _init_weights(self, module):
1187
+ # important: this ported version of the model isn't meant for training from scratch - only
1188
+ # inference and fine-tuning - so the proper init weights code has been removed - the m4 code
1189
+ # base should be used for training from scratch and it contains the correct code.
1190
+ std = self.config.initializer_range
1191
+ if isinstance(module, nn.Linear):
1192
+ module.weight.data.normal_(mean=0.0, std=std)
1193
+ if module.bias is not None:
1194
+ module.bias.data.zero_()
1195
+ elif isinstance(module, nn.Embedding):
1196
+ module.weight.data.normal_(mean=0.0, std=std)
1197
+ if module.padding_idx is not None:
1198
+ module.weight.data[module.padding_idx].zero_()
1199
+
1200
+ # @classmethod
1201
+ # def override_vision_model_wrapper(cls, model, config, vision_model_name, vision_model_params, torch_dtype):
1202
+ # # this can be called via from_pretrained from a class w/ head or w/o head so we extract the beheaded model version
1203
+ # beheaded_model = model.model if hasattr(model, "model") else model
1204
+ # cls.override_vision_model(beheaded_model, vision_model_name, vision_model_params, torch_dtype)
1205
+ # beheaded_model.freeze_relevant_params(config)
1206
+
1207
+
1208
+ MISTRAL_INPUTS_DOCSTRING = r"""
1209
+ Args:
1210
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1211
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1212
+ it.
1213
+
1214
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1215
+ [`PreTrainedTokenizer.__call__`] for details.
1216
+
1217
+ [What are input IDs?](../glossary#input-ids)
1218
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1219
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1220
+
1221
+ - 1 for tokens that are **not masked**,
1222
+ - 0 for tokens that are **masked**.
1223
+
1224
+ [What are attention masks?](../glossary#attention-mask)
1225
+
1226
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1227
+ [`PreTrainedTokenizer.__call__`] for details.
1228
+
1229
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
1230
+ `past_key_values`).
1231
+
1232
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1233
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1234
+ information on the default strategy.
1235
+
1236
+ - 1 indicates the head is **not masked**,
1237
+ - 0 indicates the head is **masked**.
1238
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1239
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1240
+ config.n_positions - 1]`.
1241
+
1242
+ [What are position IDs?](../glossary#position-ids)
1243
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1244
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
1245
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
1246
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
1247
+
1248
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1249
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1250
+
1251
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1252
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1253
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1254
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1255
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1256
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1257
+ model's internal embedding lookup matrix.
1258
+ use_cache (`bool`, *optional*):
1259
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1260
+ `past_key_values`).
1261
+ output_attentions (`bool`, *optional*):
1262
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1263
+ tensors for more detail.
1264
+ output_hidden_states (`bool`, *optional*):
1265
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1266
+ more detail.
1267
+ return_dict (`bool`, *optional*):
1268
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1269
+ """
1270
+
1271
+
1272
+ @add_start_docstrings(
1273
+ "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
1274
+ MISTRAL_START_DOCSTRING,
1275
+ )
1276
+ class VMistralModel(VMistralPreTrainedModel):
1277
+ """
1278
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MistralDecoderLayer`]
1279
+
1280
+ Args:
1281
+ config: VMistralConfig
1282
+ """
1283
+
1284
+ def __init__(self, config: VMistralConfig, vision_model=None):
1285
+ super().__init__(config)
1286
+ self.config = config
1287
+ self.padding_idx = config.pad_token_id
1288
+ self.vocab_size = config.vocab_size
1289
+
1290
+ self.sliding_window = config.sliding_window
1291
+
1292
+ self.embed_tokens = DecoupledEmbedding(
1293
+ num_embeddings=config.vocab_size,
1294
+ num_additional_embeddings=config.additional_vocab_size,
1295
+ embedding_dim=config.hidden_size,
1296
+ partially_freeze=config.freeze_text_layers,
1297
+ padding_idx=self.padding_idx,
1298
+ )
1299
+
1300
+ # Load an uninitialized vision model here; `from_pretrained` will later load the pre-trained weights -
1301
+ # this avoids losing those weights when `from_pretrained` is called on the main model
1302
+ self.vision_model = SiglipVisionModel(config.vision_config)
1303
+
1304
+ # Dim projection - projecting from the vision dim to the text dim
1305
+ self.modality_projection = ModalityProjection(
1306
+ embed_dim_in=self.config.vision_config.hidden_size, embed_dim_out=self.config.hidden_size
1307
+ )
1308
+
1309
+ # Perceiver Resampler
1310
+ if config.use_resampler:
1311
+ self.perceiver_resampler = PerceiverResampler(
1312
+ config.hidden_size,
1313
+ config.perceiver_config.resampler_depth,
1314
+ config.perceiver_config.resampler_n_heads,
1315
+ config.perceiver_config.resampler_head_dim,
1316
+ config.perceiver_config.resampler_n_latents,
1317
+ config.perceiver_config.qk_layer_norms_perceiver,
1318
+ )
1319
+
1320
+ if config.use_resampler:
1321
+ self.image_seq_len = config.perceiver_config.resampler_n_latents
1322
+ else:
1323
+ self.image_seq_len = (
1324
+ config.vision_config.image_size // config.vision_config.patch_size
1325
+ ) ** 2 # TODO: pretty sure that does not work for CLIP models since there is the CLS token
1326
+ self.image_token_id = self.config.image_token_id
1327
+
1328
+ self.layers = nn.ModuleList([MistralDecoderLayer(config) for _ in range(config.num_hidden_layers)])
1329
+
1330
+ self.gradient_checkpointing = False
1331
+
1332
+ self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1333
+
1334
+ # Initialize weights and apply final processing
1335
+ self.post_init()
1336
+
1337
+ self.freeze_relevant_params(config)
1338
+
1339
+ def freeze_relevant_params(self, config=None):
1340
+ if config is None:
1341
+ config = self.config
1342
+
1343
+ if config.freeze_text_layers:
1344
+ self.freeze_text_layers(config.freeze_text_module_exceptions)
1345
+
1346
+ if config.freeze_vision_layers:
1347
+ freeze_model(self.vision_model, module_exceptions=config.freeze_vision_module_exceptions)
1348
+
1349
+ def freeze_text_layers(self, module_exceptions):
1350
+ for module in [self.layers, self.norm]:
1351
+ freeze_model(module, module_exceptions=module_exceptions)
1352
+
1353
+ def get_input_embeddings(self):
1354
+ return self.embed_tokens
1355
+
1356
+ def set_input_embeddings(self, value):
1357
+ self.embed_tokens = value
1358
+
1359
+ def inputs_merger(
1360
+ self,
1361
+ input_ids: torch.LongTensor = None,
1362
+ inputs_embeds: Optional[torch.Tensor] = None,
1363
+ image_hidden_states: Optional[torch.Tensor] = None,
1364
+ ):
1365
+ """
1366
+ This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM.
1367
+ The merging happens as follows:
1368
+ - The text token sequence is: `tok_1 tok_2 tok_3 <fake_token_around_image> <image> <image> ... <image> <fake_token_around_image> tok_4`.
1369
+ - We get the image hidden states for the image through the vision encoder (and potentially the perceiver), and that hidden state is then projected into the text embedding space.
1370
+ We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is for batch_size of 1 image and hidden_dim is the hidden_dim of the LM transformer.
1371
+ - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_tok_around_image vector_tok_4`. That sequence is fed to the LM.
1372
+ - To fit the format of that sequence, `input_ids`, `inputs_embeds`, and `attention_mask` are all adapted to insert the image hidden states.
1373
+ """
1374
+ batch_size = input_ids.size(0)
1375
+
1376
+ if inputs_embeds is not None:
1377
+ new_inputs_embeds = inputs_embeds.clone()
1378
+
1379
+ if image_hidden_states is not None:
1380
+ vision_pipeline_output_seq_len = image_hidden_states.shape[1]
1381
+ vision_hidden_size = image_hidden_states.shape[2]
1382
+ # Get the number of images for each example
1383
+ num_images = (input_ids == self.image_token_id).sum(dim=-1) // self.image_seq_len
1384
+ cum_num_images = num_images.cumsum(dim=-1)
1385
+ for batch_idx in range(batch_size):
1386
+ # Get the number of images for this particular example
1387
+ example_num_images = num_images[batch_idx]
1388
+ # Get the image_hidden_states corresponding to True images for the example, so get rid of the padding images.
1389
+ start = 0 if batch_idx == 0 else cum_num_images[batch_idx - 1]
1390
+ end = cum_num_images[batch_idx]
1391
+ example_true_image_hidden_states = image_hidden_states[start:end]
1392
+ if (
1393
+ new_inputs_embeds[batch_idx][input_ids[batch_idx] == self.image_token_id].shape[0]
1394
+ != example_num_images * vision_pipeline_output_seq_len
1395
+ ):
1396
+ raise ValueError(
1397
+ "new_inputs_embeds to replace has shape[0]:"
1398
+ f" {new_inputs_embeds[batch_idx][input_ids[batch_idx] == self.image_token_id].shape[0]} but"
1399
+ " should have shape[0]:"
1400
+ f" {example_num_images}*{vision_pipeline_output_seq_len}={example_num_images * vision_pipeline_output_seq_len} "
1401
+ )
1402
+ # Insert the image_hidden_states
1403
+ new_inputs_embeds[batch_idx][input_ids[batch_idx] == self.image_token_id] = (
1404
+ example_true_image_hidden_states.view(
1405
+ example_num_images * vision_pipeline_output_seq_len,
1406
+ vision_hidden_size,
1407
+ )
1408
+ )
1409
+
1410
+ return_dict = {}
1411
+ if inputs_embeds is not None:
1412
+ return_dict["inputs_embeds"] = new_inputs_embeds
1413
+
1414
+ return return_dict
1415
+
1416
+ @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
1417
+ def forward(
1418
+ self,
1419
+ input_ids: torch.LongTensor = None,
1420
+ attention_mask: Optional[torch.Tensor] = None,
1421
+ position_ids: Optional[torch.LongTensor] = None,
1422
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1423
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1424
+ pixel_values: Optional[torch.FloatTensor] = None,
1425
+ image_hidden_states: Optional[torch.FloatTensor] = None,
1426
+ use_cache: Optional[bool] = None,
1427
+ output_attentions: Optional[bool] = None,
1428
+ output_hidden_states: Optional[bool] = None,
1429
+ return_dict: Optional[bool] = None,
1430
+ ) -> Union[Tuple, VMistralBaseModelOutputWithPast]:
1431
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1432
+
1433
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1434
+ output_hidden_states = (
1435
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1436
+ )
1437
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1438
+
1439
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1440
+
1441
+ # retrieve input_ids and inputs_embeds
1442
+ if input_ids is not None and inputs_embeds is not None:
1443
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1444
+ elif input_ids is not None:
1445
+ batch_size, seq_length = input_ids.shape
1446
+ elif inputs_embeds is not None:
1447
+ batch_size, seq_length, _ = inputs_embeds.shape
1448
+ else:
1449
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1450
+
1451
+ seq_length_with_past = seq_length
1452
+ past_key_values_length = 0
1453
+
1454
+ if past_key_values is not None:
1455
+ past_key_values_length = past_key_values[0][0].shape[2]
1456
+ seq_length_with_past = seq_length_with_past + past_key_values_length
1457
+
1458
+ if position_ids is None:
1459
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1460
+ position_ids = torch.arange(
1461
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1462
+ )
1463
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1464
+ else:
1465
+ position_ids = position_ids.view(-1, seq_length).long()
1466
+
1467
+ if inputs_embeds is None:
1468
+ inputs_embeds = self.embed_tokens(input_ids)
1469
+
1470
+ # START VISUAL INPUTS INTEGRATION
1471
+ if pixel_values is not None and image_hidden_states is not None:
1472
+ raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
1473
+ elif pixel_values is not None:
1474
+ pixel_values = pixel_values.to(dtype=self.dtype, device=input_ids.device) # fp16 compatibility
1475
+ batch_size, num_images = pixel_values.size(0), pixel_values.size(1)
1476
+ pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])
1477
+ # Remove padding images - padding images are full 0.
1478
+ real_images_inds = pixel_values.sum(dim=(-1, -2, -3)) != 0.0
1479
+ pixel_values = pixel_values[real_images_inds]
1480
+ # Get sequence from the vision encoder
1481
+ image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
1482
+
1483
+ # Modality projection
1484
+ image_hidden_states = self.modality_projection(image_hidden_states)
1485
+
1486
+ if self.config.use_resampler:
1487
+ image_hidden_states = self.perceiver_resampler(image_hidden_states)
1488
+ elif image_hidden_states is not None:
1489
+ image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
1490
+
1491
+ if past_key_values is None:
1492
+ # When we generate, we don't want to replace the potential image_token_id that we generated by images
1493
+ # that simply don't exist
1494
+ new_inp = self.inputs_merger(
1495
+ input_ids=input_ids,
1496
+ inputs_embeds=inputs_embeds,
1497
+ image_hidden_states=image_hidden_states,
1498
+ )
1499
+ inputs_embeds = new_inp["inputs_embeds"]
1500
+
1501
+ # Could add token type embeddings here (image token vs text token),
1502
+ # something like inputs_embeds += self.token_types(token_types)
1503
+
1504
+ # embed positions
1505
+ if (
1506
+ attention_mask is not None
1507
+ and hasattr(self.config, "_flash_attn_2_enabled")
1508
+ and self.config._flash_attn_2_enabled
1509
+ and past_key_values is not None
1510
+ ):
1511
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1512
+ if is_padding_right:
1513
+ raise ValueError(
1514
+ "You are attempting to perform batched generation with padding_side='right'"
1515
+ " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
1516
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1517
+ )
1518
+
1519
+ if getattr(self.config, "_flash_attn_2_enabled", False):
1520
+ # 2d mask is passed through the layers
1521
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1522
+ else:
1523
+ # 4d mask is passed through the layers
1524
+ attention_mask = _prepare_4d_causal_attention_mask(
1525
+ attention_mask,
1526
+ (batch_size, seq_length),
1527
+ inputs_embeds,
1528
+ past_key_values_length,
1529
+ sliding_window=self.config.sliding_window,
1530
+ )
1531
+ attention_mask[attention_mask == -float("inf")] = torch.finfo(self.dtype).min
1532
+
1533
+ hidden_states = inputs_embeds
1534
+
1535
+ if self.gradient_checkpointing and self.training:
1536
+ if use_cache:
1537
+ logger.warning_once(
1538
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1539
+ )
1540
+ use_cache = False
1541
+
1542
+ # decoder layers
1543
+ all_hidden_states = () if output_hidden_states else None
1544
+ all_self_attns = () if output_attentions else None
1545
+ next_decoder_cache = () if use_cache else None
1546
+
1547
+ for idx, decoder_layer in enumerate(self.layers):
1548
+ if output_hidden_states:
1549
+ all_hidden_states += (hidden_states,)
1550
+
1551
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
1552
+
1553
+ if self.gradient_checkpointing and self.training:
1554
+ layer_outputs = self._gradient_checkpointing_func(
1555
+ decoder_layer.__call__,
1556
+ hidden_states,
1557
+ attention_mask,
1558
+ position_ids,
1559
+ past_key_value,
1560
+ output_attentions,
1561
+ use_cache,
1562
+ )
1563
+ else:
1564
+ layer_outputs = decoder_layer(
1565
+ hidden_states,
1566
+ attention_mask=attention_mask,
1567
+ position_ids=position_ids,
1568
+ past_key_value=past_key_value,
1569
+ output_attentions=output_attentions,
1570
+ use_cache=use_cache,
1571
+ )
1572
+
1573
+ hidden_states = layer_outputs[0]
1574
+
1575
+ if use_cache:
1576
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
1577
+
1578
+ if output_attentions:
1579
+ all_self_attns += (layer_outputs[1],)
1580
+
1581
+ hidden_states = self.norm(hidden_states)
1582
+
1583
+ # add hidden states from the last decoder layer
1584
+ if output_hidden_states:
1585
+ all_hidden_states += (hidden_states,)
1586
+
1587
+ next_cache = next_decoder_cache if use_cache else None
1588
+ if not return_dict:
1589
+ return tuple(
1590
+ v
1591
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, image_hidden_states]
1592
+ if v is not None
1593
+ )
1594
+ return VMistralBaseModelOutputWithPast(
1595
+ last_hidden_state=hidden_states,
1596
+ past_key_values=next_cache,
1597
+ hidden_states=all_hidden_states,
1598
+ attentions=all_self_attns,
1599
+ image_hidden_states=image_hidden_states,
1600
+ )
1601
+
1602
+
1603
+ class VMistralForVisionText2Text(VMistralPreTrainedModel):
1604
+ _tied_weights_keys = ["lm_head.weight"]
1605
+
1606
+ def __init__(self, config, vision_model=None):
1607
+ super().__init__(config)
1608
+ self.model = VMistralModel(config, vision_model=vision_model)
1609
+ self.image_token_id = self.config.image_token_id
1610
+ self.lm_head = DecoupledLinear(
1611
+ in_features=config.hidden_size,
1612
+ out_features=config.vocab_size,
1613
+ out_additional_features=config.additional_vocab_size,
1614
+ bias=False,
1615
+ partially_freeze=config.freeze_lm_head,
1616
+ )
1617
+
1618
+ # Initialize weights and apply final processing
1619
+ self.post_init()
1620
+
1621
+ def get_input_embeddings(self):
1622
+ return self.model.embed_tokens
1623
+
1624
+ def set_input_embeddings(self, value):
1625
+ self.model.embed_tokens = value
1626
+
1627
+ def get_output_embeddings(self):
1628
+ return self.lm_head
1629
+
1630
+ def set_output_embeddings(self, new_embeddings):
1631
+ self.lm_head = new_embeddings
1632
+
1633
+ def set_decoder(self, decoder):
1634
+ self.model = decoder
1635
+
1636
+ def get_decoder(self):
1637
+ return self.model
1638
+
1639
+ def tie_weights(self):
1640
+ """
1641
+ Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of DecoupledLinear and DecoupledEmbedding.
1642
+ """
1643
+ output_embeddings = self.get_output_embeddings()
1644
+ input_embeddings = self.get_input_embeddings()
1645
+
1646
+ if getattr(self.config, "tie_word_embeddings", True):
1647
+ output_embeddings.weight = input_embeddings.weight
1648
+ if input_embeddings.num_additional_embeddings > 0:
1649
+ assert output_embeddings.out_additional_features == input_embeddings.num_additional_embeddings
1650
+ output_embeddings.additional_fc.weight = input_embeddings.additional_embedding.weight
1651
+
1652
+ if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
1653
+ output_embeddings.out_features = input_embeddings.num_embeddings
1654
+ if hasattr(output_embeddings, "out_additional_features") and hasattr(
1655
+ input_embeddings, "num_additional_embeddings"
1656
+ ):
1657
+ output_embeddings.out_additional_features = input_embeddings.num_additional_embeddings
1658
+
1659
+ @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
1660
+ @replace_return_docstrings(output_type=VMistralCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1661
+ def forward(
1662
+ self,
1663
+ input_ids: torch.LongTensor = None,
1664
+ attention_mask: Optional[torch.Tensor] = None,
1665
+ position_ids: Optional[torch.LongTensor] = None,
1666
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1667
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1668
+ pixel_values: Optional[torch.FloatTensor] = None,
1669
+ image_hidden_states: Optional[torch.FloatTensor] = None,
1670
+ labels: Optional[torch.LongTensor] = None,
1671
+ use_cache: Optional[bool] = None,
1672
+ output_attentions: Optional[bool] = None,
1673
+ output_hidden_states: Optional[bool] = None,
1674
+ return_dict: Optional[bool] = None,
1675
+ ) -> Union[Tuple, VMistralCausalLMOutputWithPast]:
1676
+ r"""
1677
+ Args:
1678
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1679
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1680
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1681
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1682
+
1683
+ Returns:
1684
+
1685
+ """
1686
+
1687
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1688
+ output_hidden_states = (
1689
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1690
+ )
1691
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1692
+
1693
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
1694
+ outputs = self.model(
1695
+ input_ids=input_ids,
1696
+ attention_mask=attention_mask,
1697
+ position_ids=position_ids,
1698
+ past_key_values=past_key_values,
1699
+ inputs_embeds=inputs_embeds,
1700
+ pixel_values=pixel_values,
1701
+ image_hidden_states=image_hidden_states,
1702
+ use_cache=use_cache,
1703
+ output_attentions=output_attentions,
1704
+ output_hidden_states=output_hidden_states,
1705
+ return_dict=return_dict,
1706
+ )
1707
+
1708
+ hidden_states = outputs[0]
1709
+ logits = self.lm_head(hidden_states)
1710
+ logits = logits.float()
1711
+
1712
+ loss = None
1713
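+ # Standard next-token loss: logits and labels are shifted by one position, padded positions
+ # are dropped via the attention mask, and image placeholder tokens are excluded through
+ # `ignore_index=self.image_token_id`.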
+ if labels is not None:
1714
+ labels = labels.to(logits.device)
1715
+ # Shift so that tokens < n predict n
1716
+ if attention_mask is not None:
1717
+ shift_attention_mask = attention_mask[..., 1:].to(logits.device)
1718
+ shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
1719
+ shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
1720
+ else:
1721
+ shift_logits = logits[..., :-1, :].contiguous()
1722
+ shift_labels = labels[..., 1:].contiguous()
1723
+ # Flatten the tokens
1724
+ loss_fct = CrossEntropyLoss(ignore_index=self.image_token_id)
1725
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1726
+
1727
+ if not return_dict:
1728
+ output = (logits,) + outputs[1:]
1729
+ return (loss,) + output if loss is not None else output
1730
+
1731
+ return VMistralCausalLMOutputWithPast(
1732
+ loss=loss,
1733
+ logits=logits,
1734
+ past_key_values=outputs.past_key_values,
1735
+ hidden_states=outputs.hidden_states,
1736
+ attentions=outputs.attentions,
1737
+ image_hidden_states=outputs.image_hidden_states,
1738
+ )
1739
+
1740
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
1741
+ image_hidden_states = kwargs.pop("image_hidden_states", None)
1742
+ if image_hidden_states is not None:
1743
+ kwargs["pixel_values"] = None
1744
+ inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs)
1745
+ unwanted_kwargs = ["token_type_ids"]
1746
+ for kwarg in unwanted_kwargs:
1747
+ inputs.pop(kwarg, None)
1748
+ return inputs
1749
+
1750
+ @staticmethod
1751
+ def _expand_inputs_for_generation(
1752
+ *args,
1753
+ **model_kwargs,
1754
+ ):
1755
+ return expand_inputs_for_generation(*args, **model_kwargs)
1756
+
1757
+ @staticmethod
1758
+ def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder):
1759
+ return update_model_kwargs_for_generation(outputs, model_kwargs)
1760
+
1761
+ @staticmethod
1762
+ def _reorder_cache(past, beam_idx):
1763
+ reordered_past = ()
1764
+ for layer_past in past:
1765
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1766
+ return reordered_past
modeling_web.py ADDED
@@ -0,0 +1,681 @@
1
+ from dataclasses import dataclass
2
+ import inspect
3
+ import warnings
4
+ from typing import List, Optional, Tuple, Union
5
+ import sys
6
+ import os
7
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint
11
+ from torch import nn
12
+ from torch.nn import CrossEntropyLoss
13
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
14
+ from transformers.utils import (
15
+ is_flash_attn_2_available
16
+ )
17
+ from transformers import PreTrainedModel
18
+ from transformers.modeling_outputs import ModelOutput
19
+
20
+ from .configuration_vmistral import VMistralConfig
21
+ from .vision import SiglipVisionModel
22
+ from .modeling_vmistral import *
23
+ from .generation_utils import TreeBuilder, WebGenerationMixin
24
+ import time
25
+
26
+
27
+ if is_flash_attn_2_available():
28
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
29
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
30
+
31
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
32
+
33
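+ # Same fields as the usual causal-LM output, plus `html_tree`, which carries the HTML
+ # structure state (and its growing web attention mask) across generation steps.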
+ @dataclass
34
+ class WebLMOutputWithPast(ModelOutput):
35
+ loss: Optional[torch.FloatTensor] = None
36
+ logits: torch.FloatTensor = None
37
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
38
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
39
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
40
+ image_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
41
+ html_tree: TreeBuilder = None
42
+
43
+
44
+ class WebAttention(nn.Module):
45
+ """
46
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
47
+ and "Generating Long Sequences with Sparse Transformers".
48
+ """
49
+
50
+ def __init__(self, config: VMistralConfig, qk_layer_norms: bool = False):
51
+ super().__init__()
52
+ self.config = config
53
+ self.hidden_size = config.hidden_size
54
+ self.num_heads = config.num_attention_heads
55
+ self.head_dim = self.hidden_size // self.num_heads
56
+ self.num_key_value_heads = config.num_key_value_heads
57
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
58
+ self.max_position_embeddings = config.max_position_embeddings
59
+ self.rope_theta = config.rope_theta
60
+ self.is_causal = True
61
+
62
+ if (self.head_dim * self.num_heads) != self.hidden_size:
63
+ raise ValueError(
64
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
65
+ f" and `num_heads`: {self.num_heads})."
66
+ )
67
+
68
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
69
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
70
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
71
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
72
+
73
+ self.qk_layer_norms = qk_layer_norms
74
+ if self.qk_layer_norms:
75
+ self.q_layer_norm = MistralRMSNorm(self.head_dim, eps=config.rms_norm_eps)
76
+ self.k_layer_norm = MistralRMSNorm(self.head_dim, eps=config.rms_norm_eps)
77
+
78
+ self.rotary_emb = MistralRotaryEmbedding(
79
+ self.head_dim,
80
+ max_position_embeddings=self.max_position_embeddings,
81
+ base=self.rope_theta,
82
+ )
83
+ self.attention_dropout = config.attention_dropout
84
+
85
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
86
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
87
+
88
+ def forward(
89
+ self,
90
+ hidden_states: torch.Tensor,
91
+ key_value_states: Optional[torch.Tensor] = None,
92
+ attention_mask: Optional[torch.Tensor] = None,
93
+ web_attention_mask: Optional[torch.Tensor] = None,
94
+ position_ids: Optional[torch.LongTensor] = None,
95
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
96
+ output_attentions: bool = False,
97
+ use_cache: bool = False,
98
+ **kwargs,
99
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
100
+ if "padding_mask" in kwargs:
101
+ warnings.warn(
102
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use"
103
+ " `attention_mask` instead.`"
104
+ )
105
+
106
+ bsz, q_len, _ = hidden_states.size()
107
+
108
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
109
+ key_states = (
110
+ self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
111
+ )
112
+ value_states = (
113
+ self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
114
+ )
115
+
116
+ kv_seq_len = key_states.shape[-2]
117
+ if past_key_value is not None:
118
+ kv_seq_len += past_key_value[0].shape[-2]
119
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
120
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
121
+
122
+ if past_key_value is not None:
123
+ # reuse k, v, self_attention
124
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
125
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
126
+
127
+ past_key_value = (key_states, value_states) if use_cache else None
128
+
129
+ if self.qk_layer_norms:
130
+ query_states = self.q_layer_norm(query_states)
131
+ key_states = self.k_layer_norm(key_states)
132
+
133
+ # repeat k/v heads if n_kv_heads < n_heads
134
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
135
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
136
+ web_attention_range = self.config.web_attention_range
137
+
138
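+ # The attention heads are treated as 8 groups: `web_attention_range` of them attend through
+ # the HTML-structure ("web") mask, the remaining groups through the ordinary causal mask.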
+ def split_tensor(tensor):
139
+ if int(web_attention_range) == 8:
140
+ return
141
+ fraction = float(web_attention_range) / 8
142
+ split_size_2 = int(self.num_heads * fraction)
143
+ split_size_1 = self.num_heads - split_size_2
144
+ return torch.split(tensor, [split_size_1, split_size_2], dim=1)
145
+
146
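+ # Scaled-dot-product attention below is pinned to the math backend (flash / mem-efficient
+ # disabled), presumably because explicit, pre-built attention masks are passed in.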
+ if int(web_attention_range) != 8:
147
+ query_states_1, query_states_2 = split_tensor(query_states)
148
+ key_states_1, key_states_2 = split_tensor(key_states)
149
+ value_states_1, value_states_2 = split_tensor(value_states)
150
+
151
+ with torch.backends.cuda.sdp_kernel(
152
+ enable_flash=False, enable_math=True, enable_mem_efficient=False
153
+ ):
154
+ attn_output_1 = F.scaled_dot_product_attention(query_states_1, key_states_1, value_states_1, attn_mask=attention_mask)
155
+
156
+ attn_output_2 = F.scaled_dot_product_attention(query_states_2, key_states_2, value_states_2, attn_mask=web_attention_mask)
157
+ attn_output = torch.cat([attn_output_1, attn_output_2], dim=1)
158
+ else:
159
+ with torch.backends.cuda.sdp_kernel(
160
+ enable_flash=False, enable_math=True, enable_mem_efficient=False
161
+ ):
162
+ attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=web_attention_mask)
163
+
164
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
165
+ raise ValueError(
166
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
167
+ f" {attn_output.size()}"
168
+ )
169
+
170
+ attn_output = attn_output.transpose(1, 2).contiguous()
171
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
172
+
173
+ attn_output = self.o_proj(attn_output)
174
+
175
+ if not output_attentions:
176
+ attn_weights = None
177
+
178
+ return attn_output, attn_weights, past_key_value
179
+
180
+
181
+ class WebFlashAttention2(WebAttention):
182
+ """
183
+ Web flash attention module. This module inherits from `WebAttention` as the weights of the module stay
184
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
185
+ flash attention and deal with padding tokens in case the input contains any of them.
186
+ """
187
+
188
+ class WebDecoderLayer(nn.Module):
189
+ def __init__(self, config: VMistralConfig):
190
+ super().__init__()
191
+ self.hidden_size = config.hidden_size
192
+ self.self_attn = (
193
+ WebAttention(config=config)
194
+ if not getattr(config, "_flash_attn_2_enabled", False)
195
+ else WebFlashAttention2(config)
196
+ )
197
+ self.mlp = MistralMLP(config)
198
+ self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
199
+ self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
200
+
201
+ def forward(
202
+ self,
203
+ hidden_states: torch.Tensor,
204
+ attention_mask: Optional[torch.Tensor] = None,
205
+ web_attention_mask: Optional[torch.Tensor] = None,
206
+ position_ids: Optional[torch.LongTensor] = None,
207
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
208
+ output_attentions: Optional[bool] = False,
209
+ use_cache: Optional[bool] = False,
210
+ **kwargs,
211
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
212
+ if "padding_mask" in kwargs:
213
+ warnings.warn(
214
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use"
215
+ " `attention_mask` instead.`"
216
+ )
217
+ """
218
+ Args:
219
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
220
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
221
+ `(batch, sequence_length)` where padding elements are indicated by 0.
222
+ output_attentions (`bool`, *optional*):
223
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
224
+ returned tensors for more detail.
225
+ use_cache (`bool`, *optional*):
226
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
227
+ (see `past_key_values`).
228
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
229
+ """
230
+
231
+ residual = hidden_states
232
+
233
+ hidden_states = self.input_layernorm(hidden_states)
234
+ # Self Attention
235
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
236
+ hidden_states=hidden_states,
237
+ attention_mask=attention_mask,
238
+ web_attention_mask=web_attention_mask,
239
+ position_ids=position_ids,
240
+ past_key_value=past_key_value,
241
+ output_attentions=output_attentions,
242
+ use_cache=use_cache,
243
+ )
244
+ hidden_states = residual + hidden_states
245
+
246
+ # Fully Connected
247
+ residual = hidden_states
248
+ hidden_states = self.post_attention_layernorm(hidden_states)
249
+ hidden_states = self.mlp(hidden_states)
250
+ hidden_states = residual + hidden_states
251
+
252
+ outputs = (hidden_states,)
253
+
254
+ if output_attentions:
255
+ outputs += (self_attn_weights,)
256
+
257
+ if use_cache:
258
+ outputs += (present_key_value,)
259
+
260
+ return outputs
261
+
262
+ class WebPreTrainedModel(PreTrainedModel):
263
+ config_class = VMistralConfig
264
+ base_model_prefix = "model"
265
+ supports_gradient_checkpointing = True
266
+ _no_split_modules = ["WebDecoderLayer"]
267
+ _skip_keys_device_placement = "past_key_values"
268
+ _supports_sdpa = False
269
+
270
+
271
+ class WebModel(WebPreTrainedModel, VMistralModel):
272
+ """
273
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MistralDecoderLayer`]
274
+
275
+ Args:
276
+ config: VMistralConfig
277
+ """
278
+
279
+ def __init__(self, config: VMistralConfig, vision_model=None):
280
+ super().__init__(config)
281
+ self.config = config
282
+ self.padding_idx = config.pad_token_id
283
+ self.vocab_size = config.vocab_size
284
+
285
+ self.sliding_window = config.sliding_window
286
+
287
+ self.embed_tokens = DecoupledEmbedding(
288
+ num_embeddings=config.vocab_size,
289
+ num_additional_embeddings=config.additional_vocab_size,
290
+ embedding_dim=config.hidden_size,
291
+ partially_freeze=config.freeze_text_layers,
292
+ padding_idx=self.padding_idx,
293
+ )
294
+
295
+ # Load an uninitialized vision model here; `from_pretrained` will later load the pre-trained weights -
296
+ # this avoids losing those weights when `from_pretrained` is called on the main model
297
+ self.vision_model = SiglipVisionModel(config.vision_config)
298
+
299
+ # Dim projection - projecting from the vision dim to the text dim
300
+ self.modality_projection = ModalityProjection(
301
+ embed_dim_in=self.config.vision_config.hidden_size, embed_dim_out=self.config.hidden_size
302
+ )
303
+
304
+ # Perceiver Resampler
305
+ if config.use_resampler:
306
+ self.perceiver_resampler = PerceiverResampler(
307
+ config.hidden_size,
308
+ config.perceiver_config.resampler_depth,
309
+ config.perceiver_config.resampler_n_heads,
310
+ config.perceiver_config.resampler_head_dim,
311
+ config.perceiver_config.resampler_n_latents,
312
+ config.perceiver_config.qk_layer_norms_perceiver,
313
+ )
314
+
315
+ if config.use_resampler:
316
+ self.image_seq_len = config.perceiver_config.resampler_n_latents
317
+ else:
318
+ self.image_seq_len = (
319
+ config.vision_config.image_size // config.vision_config.patch_size
320
+ ) ** 2 # TODO: pretty sure that does not work for CLIP models since there is the CLS token
321
+ self.image_token_id = self.config.image_token_id
322
+
323
+ self.layers = nn.ModuleList([WebDecoderLayer(config) for _ in range(config.num_hidden_layers)])
324
+
325
+ self.gradient_checkpointing = False
326
+
327
+ self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
328
+
329
+ # Initialize weights and apply final processing
330
+ self.post_init()
331
+
332
+ self.freeze_relevant_params(config)
333
+
334
+ def forward(
335
+ self,
336
+ input_ids: torch.LongTensor = None,
337
+ attention_mask: Optional[torch.Tensor] = None,
338
+ web_attention_mask: Optional[torch.Tensor] = None,
339
+ position_ids: Optional[torch.LongTensor] = None,
340
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
341
+ inputs_embeds: Optional[torch.FloatTensor] = None,
342
+ pixel_values: Optional[torch.FloatTensor] = None,
343
+ image_hidden_states: Optional[torch.FloatTensor] = None,
344
+ use_cache: Optional[bool] = None,
345
+ output_attentions: Optional[bool] = None,
346
+ output_hidden_states: Optional[bool] = None,
347
+ return_dict: Optional[bool] = None,
348
+ ) -> Union[Tuple, VMistralBaseModelOutputWithPast]:
349
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
350
+
351
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
352
+ output_hidden_states = (
353
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
354
+ )
355
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
356
+
357
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
358
+
359
+ # retrieve input_ids and inputs_embeds
360
+ if input_ids is not None and inputs_embeds is not None:
361
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
362
+ elif input_ids is not None:
363
+ batch_size, seq_length = input_ids.shape
364
+ elif inputs_embeds is not None:
365
+ batch_size, seq_length, _ = inputs_embeds.shape
366
+ else:
367
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
368
+
369
+ seq_length_with_past = seq_length
370
+ past_key_values_length = 0
371
+
372
+ if past_key_values is not None:
373
+ past_key_values_length = past_key_values[0][0].shape[2]
374
+ seq_length_with_past = seq_length_with_past + past_key_values_length
375
+
376
+ if position_ids is None:
377
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
378
+ position_ids = torch.arange(
379
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
380
+ )
381
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
382
+ else:
383
+ position_ids = position_ids.view(-1, seq_length).long()
384
+
385
+ if inputs_embeds is None:
386
+ inputs_embeds = self.embed_tokens(input_ids)
387
+
388
+ # START VISUAL INPUTS INTEGRATION
389
+ if pixel_values is not None and image_hidden_states is not None:
390
+ raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
391
+ elif pixel_values is not None:
392
+ pixel_values = pixel_values.to(dtype=self.dtype, device=input_ids.device) # fp16 compatibility
393
+ batch_size, num_images = pixel_values.size(0), pixel_values.size(1)
394
+
395
+ # this change allows multiple images in a single batch
396
+ pixel_values = pixel_values.contiguous().view(batch_size, num_images, *pixel_values.shape[2:])
397
+ # # Remove padding images - padding images are full 0.
398
+ # real_images_inds = pixel_values.sum(dim=(-1, -2, -3)) != 0.0
399
400
+ # pixel_values = pixel_values[real_images_inds]
401
+ # # Get sequence from the vision encoder
402
+ # print("shape_pixel", pixel_values.shape)
403
+ image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
404
+
405
+ # Modality projection
406
+ image_hidden_states = self.modality_projection(image_hidden_states)
407
+
408
+ if self.config.use_resampler:
409
+ image_hidden_states = self.perceiver_resampler(image_hidden_states)
410
+ elif image_hidden_states is not None:
411
+ image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
412
+
413
+ if past_key_values is None:
414
+ # When we generate, we don't want to replace the potential image_token_id that we generated by images
415
+ # that simply don't exist
416
+ new_inp = self.inputs_merger(
417
+ input_ids=input_ids,
418
+ inputs_embeds=inputs_embeds,
419
+ image_hidden_states=image_hidden_states,
420
+ )
421
+ inputs_embeds = new_inp["inputs_embeds"]
422
+
423
+ # Could add token type embeddings here (image token vs text token),
424
+ # something like inputs_embeds += self.token_types(token_types)
425
+
426
+ # embed positions
427
+ if (
428
+ attention_mask is not None
429
+ and hasattr(self.config, "_flash_attn_2_enabled")
430
+ and self.config._flash_attn_2_enabled
431
+ and past_key_values is not None
432
+ ):
433
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
434
+ if is_padding_right:
435
+ raise ValueError(
436
+ "You are attempting to perform batched generation with padding_side='right'"
437
+ " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
438
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
439
+ )
440
+ # We did not implement our model using Flash attn 2
441
+ self.config._flash_attn_2_enabled = False
442
+ if not getattr(self.config, "_flash_attn_2_enabled", False):
443
+ # 2d mask is passed through the layers
444
+ # attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
445
+ attention_mask = _prepare_4d_causal_attention_mask(
446
+ attention_mask,
447
+ (batch_size, seq_length),
448
+ inputs_embeds,
449
+ past_key_values_length,
450
+ )
451
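+ # Turn the 0/1 web attention mask into an additive mask (0 where attending, a large negative
+ # value where blocked) and keep only the rows belonging to the current query tokens.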
+ web_attention_mask = web_attention_mask.unsqueeze(1)
452
+ inverted_mask = 1.0 - web_attention_mask.to(inputs_embeds.dtype)
453
+ web_attention_mask = inverted_mask.masked_fill(
454
+ inverted_mask.to(torch.bool), -1.e32
455
+ )
456
+ if input_ids is not None:
457
+ bsz, L = input_ids.size()[:2]
458
+ web_attention_mask = web_attention_mask[:, :, -L:, :]
459
+ else:
460
+ print("Exiting, wrong branch")
461
+ exit()
462
+ # 4d mask is passed through the layers
463
+ attention_mask = _prepare_4d_causal_attention_mask(
464
+ attention_mask,
465
+ (batch_size, seq_length),
466
+ inputs_embeds,
467
+ past_key_values_length,
468
+ sliding_window=self.config.sliding_window,
469
+ )
470
+ attention_mask[attention_mask == -float("inf")] = torch.finfo(self.dtype).min
471
+
472
+ hidden_states = inputs_embeds
473
+
474
+ if self.gradient_checkpointing and self.training:
475
+ if use_cache:
476
+ logger.warning_once(
477
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
478
+ )
479
+ use_cache = False
480
+
481
+ # decoder layers
482
+ all_hidden_states = () if output_hidden_states else None
483
+ all_self_attns = () if output_attentions else None
484
+ next_decoder_cache = () if use_cache else None
485
+
486
+ for idx, decoder_layer in enumerate(self.layers):
487
+ if output_hidden_states:
488
+ all_hidden_states += (hidden_states,)
489
+
490
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
491
+
492
+ if self.gradient_checkpointing and self.training:
493
+ layer_outputs = self._gradient_checkpointing_func(
494
+ decoder_layer.__call__,
495
+ hidden_states,
496
+ attention_mask,
497
+ web_attention_mask,
498
+ position_ids,
499
+ past_key_value,
500
+ output_attentions,
501
+ use_cache,
502
+ )
503
+ else:
504
+ layer_outputs = decoder_layer(
505
+ hidden_states,
506
+ attention_mask=attention_mask,
507
+ web_attention_mask=web_attention_mask,
508
+ position_ids=position_ids,
509
+ past_key_value=past_key_value,
510
+ output_attentions=output_attentions,
511
+ use_cache=use_cache,
512
+ )
513
+
514
+ hidden_states = layer_outputs[0]
515
+
516
+ if use_cache:
517
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
518
+
519
+ if output_attentions:
520
+ all_self_attns += (layer_outputs[1],)
521
+
522
+ hidden_states = self.norm(hidden_states)
523
+
524
+ # add hidden states from the last decoder layer
525
+ if output_hidden_states:
526
+ all_hidden_states += (hidden_states,)
527
+
528
+ next_cache = next_decoder_cache if use_cache else None
529
+ if not return_dict:
530
+ return tuple(
531
+ v
532
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, image_hidden_states]
533
+ if v is not None
534
+ )
535
+ return VMistralBaseModelOutputWithPast(
536
+ last_hidden_state=hidden_states,
537
+ past_key_values=next_cache,
538
+ hidden_states=all_hidden_states,
539
+ attentions=all_self_attns,
540
+ image_hidden_states=image_hidden_states,
541
+ )
542
+
543
+ class WebForVisionText2Text(WebPreTrainedModel, WebGenerationMixin):
544
+ _tied_weights_keys = ["lm_head.weight"]
545
+
546
+ def __init__(self, config, vision_model=None):
547
+ super().__init__(config)
548
+ self.model = WebModel(config, vision_model=vision_model)
549
+ self.image_token_id = self.config.image_token_id
550
+ self.lm_head = DecoupledLinear(
551
+ in_features=config.hidden_size,
552
+ out_features=config.vocab_size,
553
+ out_additional_features=config.additional_vocab_size,
554
+ bias=False,
555
+ partially_freeze=config.freeze_lm_head,
556
+ )
557
+
558
+ # Initialize weights and apply final processing
559
+ self.post_init()
560
+
561
+ def forward(
562
+ self,
563
+ input_ids: torch.LongTensor = None,
564
+ attention_mask: Optional[torch.Tensor] = None,
565
+ web_attention_mask: Optional[torch.Tensor] = None,
566
+ position_ids: Optional[torch.LongTensor] = None,
567
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
568
+ inputs_embeds: Optional[torch.FloatTensor] = None,
569
+ pixel_values: Optional[torch.FloatTensor] = None,
570
+ image_hidden_states: Optional[torch.FloatTensor] = None,
571
+ labels: Optional[torch.LongTensor] = None,
572
+ use_cache: Optional[bool] = None,
573
+ output_attentions: Optional[bool] = None,
574
+ output_hidden_states: Optional[bool] = None,
575
+ return_dict: Optional[bool] = None,
576
+ html_tree=None,
577
+ ) -> Union[Tuple, WebLMOutputWithPast]:
578
+ r"""
579
+ Args:
580
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
581
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
582
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
583
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
584
+
585
+ Returns:
586
+
587
+ """
588
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
589
+ output_hidden_states = (
590
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
591
+ )
592
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
593
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
594
+ outputs = self.model(
595
+ input_ids=input_ids,
596
+ attention_mask=attention_mask,
597
+ web_attention_mask=web_attention_mask,
598
+ position_ids=position_ids,
599
+ past_key_values=past_key_values,
600
+ inputs_embeds=inputs_embeds,
601
+ pixel_values=pixel_values,
602
+ image_hidden_states=image_hidden_states,
603
+ use_cache=use_cache,
604
+ output_attentions=output_attentions,
605
+ output_hidden_states=output_hidden_states,
606
+ return_dict=return_dict,
607
+ )
608
+
609
+ hidden_states = outputs[0]
610
+ logits = self.lm_head(hidden_states)
611
+ logits = logits.float()
612
+
613
+ loss = None
614
+ if labels is not None:
615
+ labels = labels.to(logits.device)
616
+ # Shift so that tokens < n predict n
617
+ if attention_mask is not None:
618
+ shift_attention_mask = attention_mask[..., 1:].to(logits.device)
619
+ shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
620
+ shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
621
+ else:
622
+ shift_logits = logits[..., :-1, :].contiguous()
623
+ shift_labels = labels[..., 1:].contiguous()
624
+ # Flatten the tokens
625
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
626
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
627
+
628
+ if not return_dict:
629
+ output = (logits,) + outputs[1:]
630
+ return (loss,) + output if loss is not None else output
631
+ # print(f"forward takes: {time.time()-start_time}")
632
+
633
+ return WebLMOutputWithPast(
634
+ loss=loss,
635
+ logits=logits,
636
+ past_key_values=outputs.past_key_values,
637
+ hidden_states=outputs.hidden_states,
638
+ attentions=outputs.attentions,
639
+ image_hidden_states=outputs.image_hidden_states,
640
+ html_tree=html_tree
641
+ )
642
+
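A small worked example of the label shifting in the loss above: logits and labels are offset by one position so that tokens before position n predict token n, positions where the shifted attention mask is 0 are dropped before the cross-entropy, and labels of -100 are additionally ignored by the loss itself. All tensors below are made up for illustration.

import torch
from torch.nn import CrossEntropyLoss

vocab_size, seq_len = 11, 5
logits = torch.randn(1, seq_len, vocab_size)
labels = torch.tensor([[7, 3, 9, -100, 2]])
attention_mask = torch.tensor([[1, 1, 1, 1, 0]])                 # last position is padding

shift_attention_mask = attention_mask[..., 1:]
shift_logits = logits[..., :-1, :][shift_attention_mask != 0]    # (3, vocab_size)
shift_labels = labels[..., 1:][shift_attention_mask != 0]        # tensor([3, 9, -100])

loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
print(loss.item())  # averaged over the 2 positions whose label is not -100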
643
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs
644
+ ):
645
+ image_hidden_states = kwargs.pop("image_hidden_states", None)
646
+ if image_hidden_states is not None:
647
+ kwargs["pixel_values"] = None
648
+
649
+ inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs)
650
+ web_attention_mask, html_tree = None, kwargs.get("html_tree")
651
+
652
+ if html_tree.web_attention_mask is None:
653
+ attention_mask = inputs["attention_mask"]
654
+ web_attention_mask = torch.tril(torch.ones((attention_mask.shape[-1], attention_mask.shape[-1]), dtype=attention_mask.dtype)).unsqueeze(0)
655
+ html_tree.web_attention_mask = web_attention_mask
656
+ else:
657
+ html_tree = kwargs.get("html_tree")
658
+ input_ids = inputs["input_ids"]
659
+ tokenizer = html_tree.tokenizer
660
+ cur_decoded_token = tokenizer.convert_tokens_to_string([" "] + tokenizer.convert_ids_to_tokens(input_ids[:, -1]))
661
+ web_attn_range = html_tree.update_buffer([cur_decoded_token])
662
+ bsz, L = html_tree.web_attention_mask.size()[:2]
663
+ web_attention_mask = torch.zeros((bsz, L + 1, L + 1)).type_as(html_tree.web_attention_mask)
664
+ web_attention_mask[:, :L, :L] = html_tree.web_attention_mask
665
+ web_attn_range = torch.tensor(list(range(67)) + [i + 67 for i in web_attn_range], dtype=web_attention_mask.dtype)
666
+ web_attention_mask[:, -1, web_attn_range] = 1
667
+ html_tree.web_attention_mask = web_attention_mask
668
+ if html_tree.input_ids is None:
669
+ html_tree.input_ids = input_ids
670
+ else:
671
+ html_tree.input_ids = torch.cat((html_tree.input_ids, input_ids), dim=1)
672
+
673
+ unwanted_kwargs = ["token_type_ids"]
674
+ inputs.update({
675
+ "web_attention_mask": web_attention_mask.to(inputs['attention_mask'].device),
676
+ "html_tree": html_tree,
677
+ })
678
+ for kwarg in unwanted_kwargs:
679
+ inputs.pop(kwarg, None)
680
+
681
+ return inputs
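A minimal sketch of how the web attention mask evolves during generation, mirroring `prepare_inputs_for_generation` above: on the first step a causal (lower-triangular) mask is built over the prompt, and on every later step one row and one column are appended, with the new row allowed to attend to a fixed window of 67 leading positions (hard-coded above) plus the offsets returned by `html_tree.update_buffer`. The helper name `extend_web_attention_mask` is illustrative only and not part of the repository.

import torch

def extend_web_attention_mask(web_attention_mask: torch.Tensor, web_attn_range) -> torch.Tensor:
    # web_attention_mask: (bsz, L, L) mask built so far.
    # web_attn_range: offsets (relative to position 67) the newly generated token may attend to.
    bsz, L = web_attention_mask.size()[:2]
    new_mask = torch.zeros((bsz, L + 1, L + 1)).type_as(web_attention_mask)
    new_mask[:, :L, :L] = web_attention_mask
    allowed = torch.tensor(list(range(67)) + [i + 67 for i in web_attn_range], dtype=torch.long)
    new_mask[:, -1, allowed] = 1
    return new_mask

prompt_len = 70
mask = torch.tril(torch.ones((prompt_len, prompt_len), dtype=torch.long)).unsqueeze(0)  # first step
mask = extend_web_attention_mask(mask, [0, 1, 2])                                       # one decoding step
print(mask.shape)  # torch.Size([1, 71, 71])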
preprocessor_config.json ADDED
@@ -0,0 +1,20 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoImageProcessor": "IdeficsImageProcessor",
4
+ "AutoProcessor": "IdeficsProcessor"
5
+ },
6
+ "image_mean": [
7
+ 0.5,
8
+ 0.5,
9
+ 0.5
10
+ ],
11
+ "image_num_channels": 3,
12
+ "image_processor_type": "IdeficsImageProcessor",
13
+ "image_size": 960,
14
+ "image_std": [
15
+ 0.5,
16
+ 0.5,
17
+ 0.5
18
+ ],
19
+ "processor_class": "IdeficsProcessor"
20
+ }
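A rough sketch of the preprocessing these values describe: the image is resized to 960 x 960, scaled to [0, 1], and normalized channel-wise with mean 0.5 and std 0.5, which maps pixel values into [-1, 1]. This illustrates the config above rather than calling `IdeficsImageProcessor` itself, whose exact resizing behavior is not reproduced here.

import torch
import torch.nn.functional as F

IMAGE_SIZE = 960
IMAGE_MEAN = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)
IMAGE_STD = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)

def preprocess(image_uint8: torch.Tensor) -> torch.Tensor:
    # image_uint8: (3, H, W) uint8 tensor with values in [0, 255].
    x = image_uint8.float().unsqueeze(0)                                   # (1, 3, H, W)
    x = F.interpolate(x, size=(IMAGE_SIZE, IMAGE_SIZE), mode="bilinear", align_corners=False)
    x = x.squeeze(0) / 255.0                                               # scale to [0, 1]
    return (x - IMAGE_MEAN) / IMAGE_STD                                    # normalize to [-1, 1]

pixel_values = preprocess(torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8))
print(pixel_values.shape)  # torch.Size([3, 960, 960])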
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<unk>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<unk>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,59 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<unk>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<s>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "32000": {
30
+ "content": "<fake_token_around_image>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "32001": {
38
+ "content": "<image>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ }
45
+ },
46
+ "additional_special_tokens": [],
47
+ "bos_token": "<s>",
48
+ "clean_up_tokenization_spaces": false,
49
+ "eos_token": "</s>",
50
+ "legacy": false,
51
+ "model_max_length": 1000000000000000019884624838656,
52
+ "pad_token": "<unk>",
53
+ "processor_class": "IdeficsProcessor",
54
+ "sp_model_kwargs": {},
55
+ "spaces_between_special_tokens": false,
56
+ "tokenizer_class": "LlamaTokenizer",
57
+ "unk_token": "<unk>",
58
+ "use_default_system_prompt": true
59
+ }
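The two added tokens above line up with the model config: `<fake_token_around_image>` (id 32000) and `<image>` (id 32001) sit on top of the base `vocab_size` of 32000, which is what `additional_vocab_size: 2` and `image_token_id: 32001` in `config.json` refer to. A hedged loading sketch; the repository id is a placeholder, and `trust_remote_code=True` is assumed to be needed for the custom model classes:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("<repo-id>", trust_remote_code=True)  # placeholder repo id

print(tokenizer.convert_tokens_to_ids("<fake_token_around_image>"))  # expected: 32000
print(tokenizer.convert_tokens_to_ids("<image>"))                    # expected: 32001
print(tokenizer.bos_token, tokenizer.eos_token, tokenizer.pad_token) # <s> </s> <unk>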
vision.py ADDED
@@ -0,0 +1,653 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Google AI and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ A simplified copy of https://huggingface.co/HuggingFaceM4/siglip-so400m-14-384-flash-attn2 """
16
+
17
+
18
+ from dataclasses import dataclass
19
+ from typing import Any, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+ from transformers.activations import ACT2FN
26
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
27
+ from transformers.utils import (
28
+ ModelOutput,
29
+ is_flash_attn_2_available,
30
+ logging,)
31
+
32
+ from .configuration_vmistral import VMistralVisionConfig
33
+
34
+
35
+ logger = logging.get_logger(__name__)
36
+
37
+
38
+ if is_flash_attn_2_available():
39
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
40
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
41
+
42
+
43
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
44
+ def _get_unpad_data(attention_mask):
45
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
46
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
47
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
48
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
49
+ return (
50
+ indices,
51
+ cu_seqlens,
52
+ max_seqlen_in_batch,
53
+ )
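A quick usage example for `_get_unpad_data` (with the function above in scope), using a concrete right-padded mask so the layout of the cumulative sequence lengths passed to the varlen flash-attention call is visible. The mask values are made up for illustration.

import torch

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])   # batch of 2, lengths 3 and 2, padded to 4
indices, cu_seqlens, max_seqlen = _get_unpad_data(attention_mask)
print(indices)     # tensor([0, 1, 2, 4, 5]) -- flat positions of the non-padding tokens
print(cu_seqlens)  # tensor([0, 3, 5], dtype=torch.int32)
print(max_seqlen)  # 3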
54
+
55
+
56
+ @dataclass
57
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
58
+ class SiglipVisionModelOutput(ModelOutput):
59
+ """
60
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
61
+
62
+ Args:
63
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
64
+ The image embeddings obtained by applying the projection layer to the pooler_output.
65
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
66
+ Sequence of hidden-states at the output of the last layer of the model.
67
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
68
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
69
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
70
+
71
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
72
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
73
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
74
+ sequence_length)`.
75
+
76
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
77
+ heads.
78
+ """
79
+
80
+ image_embeds: Optional[torch.FloatTensor] = None
81
+ last_hidden_state: torch.FloatTensor = None
82
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
83
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
84
+
85
+
86
+ class SiglipVisionEmbeddings(nn.Module):
87
+ def __init__(self, config: VMistralVisionConfig):
88
+ super().__init__()
89
+ self.config = config
90
+ self.embed_dim = config.hidden_size
91
+ self.image_size = config.image_size
92
+ self.patch_size = config.patch_size
93
+
94
+ self.patch_embedding = nn.Conv2d(
95
+ in_channels=config.num_channels,
96
+ out_channels=self.embed_dim,
97
+ kernel_size=self.patch_size,
98
+ stride=self.patch_size,
99
+ padding="valid",
100
+ )
101
+
102
+ self.num_patches = (self.image_size // self.patch_size) ** 2
103
+ self.num_positions = self.num_patches
104
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
105
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
106
+
107
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
108
+ # print(self.patch_embedding)
109
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
110
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
111
+
112
+ embeddings = embeddings + self.position_embedding(self.position_ids)
113
+ return embeddings
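With the vision settings from `config.json` (`image_size: 960`, `patch_size: 14`, `hidden_size: 1152`), the patch arithmetic above works out as follows; this is plain shape bookkeeping, not a run of the module.

image_size, patch_size, hidden_size = 960, 14, 1152

patches_per_side = image_size // patch_size   # 68 (the conv uses kernel 14, stride 14, "valid" padding)
num_patches = patches_per_side ** 2           # 4624, which is also num_positions
print(patches_per_side, num_patches)          # 68 4624

# A (B, 3, 960, 960) pixel tensor therefore becomes a (B, 4624, 1152) patch sequence
# after patch_embedding -> flatten(2) -> transpose(1, 2).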
114
+
115
+
116
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->Siglip
117
+ class SiglipAttention(nn.Module):
118
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
119
+
120
+ def __init__(self, config):
121
+ super().__init__()
122
+ self.config = config
123
+ self.embed_dim = config.hidden_size
124
+ self.num_heads = config.num_attention_heads
125
+ self.head_dim = self.embed_dim // self.num_heads
126
+ if self.head_dim * self.num_heads != self.embed_dim:
127
+ raise ValueError(
128
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
129
+ f" {self.num_heads})."
130
+ )
131
+ self.scale = self.head_dim**-0.5
132
+ self.dropout = config.attention_dropout
133
+
134
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
135
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
136
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
137
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
138
+
139
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
140
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
141
+
142
+ def forward(
143
+ self,
144
+ hidden_states: torch.Tensor,
145
+ attention_mask: Optional[torch.Tensor] = None,
146
+ causal_attention_mask: Optional[torch.Tensor] = None,
147
+ output_attentions: Optional[bool] = False,
148
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
149
+ """Input shape: Batch x Time x Channel"""
150
+
151
+ bsz, tgt_len, embed_dim = hidden_states.size()
152
+
153
+ # get query proj
154
+ query_states = self.q_proj(hidden_states) * self.scale
155
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
156
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
157
+
158
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
159
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
160
+ key_states = key_states.view(*proj_shape)
161
+ value_states = value_states.view(*proj_shape)
162
+
163
+ src_len = key_states.size(1)
164
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
165
+
166
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
167
+ raise ValueError(
168
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
169
+ f" {attn_weights.size()}"
170
+ )
171
+
172
+ # apply the causal_attention_mask first
173
+ if causal_attention_mask is not None:
174
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
175
+ raise ValueError(
176
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
177
+ f" {causal_attention_mask.size()}"
178
+ )
179
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
180
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
181
+
182
+ if attention_mask is not None:
183
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
184
+ raise ValueError(
185
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
186
+ )
187
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
188
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
189
+
190
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
191
+
192
+ if output_attentions:
193
+ # this operation is a bit awkward, but it's required to
194
+ # make sure that attn_weights keeps its gradient.
195
+ # In order to do so, attn_weights have to be reshaped
196
+ # twice and have to be reused in the following
197
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
198
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
199
+ else:
200
+ attn_weights_reshaped = None
201
+
202
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
203
+
204
+ attn_output = torch.bmm(attn_probs, value_states)
205
+
206
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
207
+ raise ValueError(
208
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
209
+ f" {attn_output.size()}"
210
+ )
211
+
212
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
213
+ attn_output = attn_output.transpose(1, 2)
214
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
215
+
216
+ attn_output = self.out_proj(attn_output)
217
+
218
+ return attn_output, attn_weights_reshaped
219
+
220
+
221
+ class SiglipFlashAttention2(SiglipAttention):
222
+ """
223
+ Siglip flash attention module. This module inherits from `SiglipAttention`, as the weights of the module stay
224
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
225
+ flash attention and deal with padding tokens in case the input contains any of them.
226
+ """
227
+
228
+ def __init__(self, *args, **kwargs):
229
+ super().__init__(*args, **kwargs)
230
+ self.is_causal = False # Hack to make sure we don't use a causal mask
231
+
232
+ def forward(
233
+ self,
234
+ hidden_states: torch.Tensor,
235
+ attention_mask: Optional[torch.LongTensor] = None,
236
+ position_ids: Optional[torch.LongTensor] = None,
237
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
238
+ output_attentions: bool = False,
239
+ use_cache: bool = False,
240
+ **kwargs,
241
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
242
+ output_attentions = False
243
+
244
+ bsz, q_len, _ = hidden_states.size()
245
+
246
+ query_states = self.q_proj(hidden_states)
247
+ key_states = self.k_proj(hidden_states)
248
+ value_states = self.v_proj(hidden_states)
249
+
250
+ # Flash attention requires the input to have the shape
251
+ # batch_size x seq_length x num_heads x head_dim
252
+ # therefore we just need to keep the original shape
253
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
254
+ key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
255
+ value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
256
+
257
+ kv_seq_len = key_states.shape[-2]
258
+ if past_key_value is not None:
259
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
260
+ # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
261
+ # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
262
+
263
+ # if past_key_value is not None:
264
+ # cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
265
+ # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
266
+
267
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
268
+ # to be able to avoid many of these transpose/reshape/view.
269
+ query_states = query_states.transpose(1, 2)
270
+ key_states = key_states.transpose(1, 2)
271
+ value_states = value_states.transpose(1, 2)
272
+
273
+ dropout_rate = self.dropout if self.training else 0.0
274
+
275
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
276
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
277
+ # cast them back to the correct dtype just to be sure everything works as expected.
278
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
279
+ # in fp32. (LlamaRMSNorm handles it correctly)
280
+
281
+ input_dtype = query_states.dtype
282
+ if input_dtype == torch.float32:
283
+ if torch.is_autocast_enabled():
284
+ target_dtype = torch.get_autocast_gpu_dtype()
285
+ # Handle the case where the model is quantized
286
+ elif hasattr(self.config, "_pre_quantization_dtype"):
287
+ target_dtype = self.config._pre_quantization_dtype
288
+ else:
289
+ target_dtype = self.q_proj.weight.dtype
290
+
291
+ logger.warning_once(
292
+ "The input hidden states seems to be silently casted in float32, this might be related to the fact"
293
+ " you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
294
+ f" {target_dtype}."
295
+ )
296
+
297
+ query_states = query_states.to(target_dtype)
298
+ key_states = key_states.to(target_dtype)
299
+ value_states = value_states.to(target_dtype)
300
+
301
+ attn_output = self._flash_attention_forward(
302
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
303
+ )
304
+
305
+ attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
306
+ attn_output = self.out_proj(attn_output)
307
+
308
+ if not output_attentions:
309
+ attn_weights = None
310
+
311
+ return attn_output, attn_weights
312
+
313
+ def _flash_attention_forward(
314
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
315
+ ):
316
+ """
317
+ Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token,
318
+ it first unpads the input, then computes the attention scores, and finally re-pads the output.
319
+
320
+ Args:
321
+ query_states (`torch.Tensor`):
322
+ Input query states to be passed to Flash Attention API
323
+ key_states (`torch.Tensor`):
324
+ Input key states to be passed to Flash Attention API
325
+ value_states (`torch.Tensor`):
326
+ Input value states to be passed to Flash Attention API
327
+ attention_mask (`torch.Tensor`):
328
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
329
+ position of padding tokens and 1 for the position of non-padding tokens.
330
+ dropout (`float`, *optional*):
331
+ Attention dropout
332
+ softmax_scale (`float`, *optional*):
333
+ The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
334
+ """
335
+
336
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
337
+ causal = self.is_causal and query_length != 1
338
+
339
+ # Contains at least one padding token in the sequence
340
+ if attention_mask is not None:
341
+ batch_size = query_states.shape[0]
342
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
343
+ query_states, key_states, value_states, attention_mask, query_length
344
+ )
345
+
346
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
347
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
348
+
349
+ attn_output_unpad = flash_attn_varlen_func(
350
+ query_states,
351
+ key_states,
352
+ value_states,
353
+ cu_seqlens_q=cu_seqlens_q,
354
+ cu_seqlens_k=cu_seqlens_k,
355
+ max_seqlen_q=max_seqlen_in_batch_q,
356
+ max_seqlen_k=max_seqlen_in_batch_k,
357
+ dropout_p=dropout,
358
+ softmax_scale=softmax_scale,
359
+ causal=causal,
360
+ )
361
+
362
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
363
+ else:
364
+ attn_output = flash_attn_func(
365
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
366
+ )
367
+
368
+ return attn_output
369
+
370
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
371
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
372
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
373
+
374
+ key_layer = index_first_axis(
375
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
376
+ )
377
+ value_layer = index_first_axis(
378
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
379
+ )
380
+ if query_length == kv_seq_len:
381
+ query_layer = index_first_axis(
382
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
383
+ )
384
+ cu_seqlens_q = cu_seqlens_k
385
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
386
+ indices_q = indices_k
387
+ elif query_length == 1:
388
+ max_seqlen_in_batch_q = 1
389
+ cu_seqlens_q = torch.arange(
390
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
391
+ ) # There is a memcpy here, that is very bad.
392
+ indices_q = cu_seqlens_q[:-1]
393
+ query_layer = query_layer.squeeze(1)
394
+ else:
395
+ # The -q_len: slice assumes left padding.
396
+ attention_mask = attention_mask[:, -query_length:]
397
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
398
+
399
+ return (
400
+ query_layer,
401
+ key_layer,
402
+ value_layer,
403
+ indices_q,
404
+ (cu_seqlens_q, cu_seqlens_k),
405
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
406
+ )
407
+
408
+
409
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
410
+ class SiglipMLP(nn.Module):
411
+ def __init__(self, config):
412
+ super().__init__()
413
+ self.config = config
414
+ self.activation_fn = ACT2FN[config.hidden_act]
415
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
416
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
417
+
418
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
419
+ hidden_states = self.fc1(hidden_states)
420
+ hidden_states = self.activation_fn(hidden_states)
421
+ hidden_states = self.fc2(hidden_states)
422
+ return hidden_states
423
+
424
+
425
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
426
+ class SiglipEncoderLayer(nn.Module):
427
+ def __init__(self, config: VMistralVisionConfig):
428
+ super().__init__()
429
+ self.embed_dim = config.hidden_size
430
+ self.self_attn = (
431
+ SiglipAttention(config)
432
+ # if not getattr(config, "_flash_attn_2_enabled", False)
433
+ # else SiglipFlashAttention2(config)
434
+ )
435
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
436
+ self.mlp = SiglipMLP(config)
437
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
438
+
439
+ def forward(
440
+ self,
441
+ hidden_states: torch.Tensor,
442
+ attention_mask: torch.Tensor,
443
+ causal_attention_mask: torch.Tensor,
444
+ output_attentions: Optional[bool] = False,
445
+ ) -> Tuple[torch.FloatTensor]:
446
+ """
447
+ Args:
448
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
449
+ attention_mask (`torch.FloatTensor`): attention mask of size
450
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
451
+
452
+ output_attentions (`bool`, *optional*):
453
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
454
+ returned tensors for more detail.
455
+ """
456
+ residual = hidden_states
457
+
458
+ hidden_states = self.layer_norm1(hidden_states)
459
+ hidden_states, attn_weights = self.self_attn(
460
+ hidden_states=hidden_states,
461
+ attention_mask=attention_mask,
462
+ causal_attention_mask=causal_attention_mask,
463
+ output_attentions=output_attentions,
464
+ )
465
+ hidden_states = residual + hidden_states
466
+
467
+ residual = hidden_states
468
+ hidden_states = self.layer_norm2(hidden_states)
469
+ hidden_states = self.mlp(hidden_states)
470
+ hidden_states = residual + hidden_states
471
+
472
+ outputs = (hidden_states,)
473
+
474
+ if output_attentions:
475
+ outputs += (attn_weights,)
476
+
477
+ return outputs
478
+
479
+
480
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip
481
+ class SiglipEncoder(nn.Module):
482
+ """
483
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
484
+ [`SiglipEncoderLayer`].
485
+
486
+ Args:
487
+ config: SiglipConfig
488
+ """
489
+
490
+ def __init__(self, config):
491
+ super().__init__()
492
+ self.config = config
493
+ self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
494
+ self.gradient_checkpointing = False
495
+
496
+ def forward(
497
+ self,
498
+ inputs_embeds,
499
+ attention_mask: Optional[torch.Tensor] = None,
500
+ causal_attention_mask: Optional[torch.Tensor] = None,
501
+ output_attentions: Optional[bool] = None,
502
+ output_hidden_states: Optional[bool] = None,
503
+ return_dict: Optional[bool] = None,
504
+ ) -> Union[Tuple, BaseModelOutput]:
505
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
506
+ output_hidden_states = (
507
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
508
+ )
509
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
510
+
511
+ encoder_states = () if output_hidden_states else None
512
+ all_attentions = () if output_attentions else None
513
+
514
+ hidden_states = inputs_embeds
515
+ for idx, encoder_layer in enumerate(self.layers):
516
+ if output_hidden_states:
517
+ encoder_states = encoder_states + (hidden_states,)
518
+ if self.gradient_checkpointing and self.training:
519
+
520
+ def create_custom_forward(module):
521
+ def custom_forward(*inputs):
522
+ return module(*inputs, output_attentions)
523
+
524
+ return custom_forward
525
+
526
+ layer_outputs = torch.utils.checkpoint.checkpoint(
527
+ create_custom_forward(encoder_layer),
528
+ hidden_states,
529
+ attention_mask,
530
+ causal_attention_mask,
531
+ )
532
+ else:
533
+ layer_outputs = encoder_layer(
534
+ hidden_states,
535
+ attention_mask,
536
+ causal_attention_mask,
537
+ output_attentions=output_attentions,
538
+ )
539
+
540
+ hidden_states = layer_outputs[0]
541
+
542
+ if output_attentions:
543
+ all_attentions = all_attentions + (layer_outputs[1],)
544
+
545
+ if output_hidden_states:
546
+ encoder_states = encoder_states + (hidden_states,)
547
+
548
+ if not return_dict:
549
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
550
+ return BaseModelOutput(
551
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
552
+ )
553
+
554
+
555
+ class SiglipVisionTransformer(nn.Module):
556
+ def __init__(self, config: VMistralVisionConfig):
557
+ super().__init__()
558
+ self.config = config
559
+ embed_dim = config.hidden_size
560
+
561
+ self.embeddings = SiglipVisionEmbeddings(config)
562
+ self.encoder = SiglipEncoder(config)
563
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
564
+ self.head = SiglipMultiheadAttentionPoolingHead(config)
565
+
566
+ def forward(
567
+ self,
568
+ pixel_values,
569
+ output_attentions: Optional[bool] = None,
570
+ output_hidden_states: Optional[bool] = None,
571
+ return_dict: Optional[bool] = None,
572
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
573
+ r"""
574
+ Returns:
575
+
576
+ """
577
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
578
+ output_hidden_states = (
579
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
580
+ )
581
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
582
+
583
+ hidden_states = self.embeddings(pixel_values)
584
+ # print("hidden_states", hidden_states.shape)
585
+ encoder_outputs = self.encoder(
586
+ inputs_embeds=hidden_states,
587
+ output_attentions=output_attentions,
588
+ output_hidden_states=output_hidden_states,
589
+ return_dict=return_dict,
590
+ )
591
+
592
+ last_hidden_state = encoder_outputs[0]
593
+ last_hidden_state = self.post_layernorm(last_hidden_state)
594
+
595
+ pooled_output = self.head(last_hidden_state)
596
+
597
+ if not return_dict:
598
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
599
+
600
+ return BaseModelOutputWithPooling(
601
+ last_hidden_state=last_hidden_state,
602
+ pooler_output=pooled_output,
603
+ hidden_states=encoder_outputs.hidden_states,
604
+ attentions=encoder_outputs.attentions,
605
+ )
606
+
607
+
608
+ class SiglipMultiheadAttentionPoolingHead(nn.Module):
609
+ """Multihead Attention Pooling."""
610
+
611
+ def __init__(self, config: VMistralVisionConfig):
612
+ super().__init__()
613
+
614
+ self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
615
+ self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
616
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
617
+ self.mlp = SiglipMLP(config)
618
+
619
+ def forward(self, hidden_state):
620
+ batch_size = hidden_state.shape[0]
621
+ probe = self.probe.repeat(batch_size, 1, 1)
622
+
623
+ hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
624
+
625
+ residual = hidden_state
626
+ hidden_state = self.layernorm(hidden_state)
627
+ hidden_state = residual + self.mlp(hidden_state)
628
+
629
+ return hidden_state[:, 0]
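A minimal shape check of the probe-based pooling above, re-creating the same `batch_first` `nn.MultiheadAttention` call with freshly initialized weights (illustrative only): the learned (1, 1, D) probe is broadcast over the batch and used as the query, so one pooled vector comes out per image. The sizes follow the vision settings in `config.json`.

import torch
from torch import nn

hidden_size, num_heads, batch_size, seq_len = 1152, 16, 2, 4624

probe = nn.Parameter(torch.randn(1, 1, hidden_size))
attention = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

hidden_state = torch.randn(batch_size, seq_len, hidden_size)
query = probe.repeat(batch_size, 1, 1)                    # (2, 1, 1152)
pooled = attention(query, hidden_state, hidden_state)[0]  # (2, 1, 1152)
print(pooled[:, 0].shape)                                 # torch.Size([2, 1152])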
630
+
631
+
632
+ class SiglipVisionModel(nn.Module):
633
+ def __init__(self, config: VMistralVisionConfig):
634
+ super().__init__()
635
+
636
+ self.config = config
637
+ self.vision_model = SiglipVisionTransformer(config)
638
+
639
+ def forward(
640
+ self,
641
+ pixel_values,
642
+ output_attentions: Optional[bool] = None,
643
+ output_hidden_states: Optional[bool] = None,
644
+ return_dict: Optional[bool] = None,
645
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
646
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
647
+
648
+ return self.vision_model(
649
+ pixel_values=pixel_values,
650
+ output_attentions=output_attentions,
651
+ output_hidden_states=output_hidden_states,
652
+ return_dict=return_dict,
653
+ )
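Finally, a scaled-down smoke test of the `SiglipVisionModel` defined above. The real checkpoint uses `hidden_size` 1152, 27 layers, `image_size` 960 and `patch_size` 14 (per `config.json`); here a tiny configuration keeps the forward pass cheap, and a plain `SimpleNamespace` stands in for `VMistralVisionConfig`, whose exact constructor is not shown in this diff, so the attribute names below are assumptions taken from how the modules read their config.

from types import SimpleNamespace
import torch

config = SimpleNamespace(
    hidden_size=64, intermediate_size=128, num_attention_heads=4,
    num_hidden_layers=2, num_channels=3, image_size=56, patch_size=14,
    attention_dropout=0.0, hidden_act="gelu", layer_norm_eps=1e-6,
    output_attentions=False, output_hidden_states=False, use_return_dict=True,
)

model = SiglipVisionModel(config).eval()
with torch.no_grad():
    out = model(pixel_values=torch.randn(2, 3, 56, 56))
print(out.last_hidden_state.shape)  # torch.Size([2, 16, 64]) -- (56 // 14) ** 2 = 16 patches
print(out.pooler_output.shape)      # torch.Size([2, 64])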