XFFXFF committed
Commit 340ef3d · 1 Parent(s): 53f7487

remove unused python files

Files changed (7)
  1. configuration_aria.py +0 -114
  2. modeling_aria.py +0 -365
  3. moe_lm.py +0 -679
  4. processing_aria.py +0 -305
  5. projector.py +0 -189
  6. vision_encoder.py +0 -152
  7. vision_processor.py +0 -321
configuration_aria.py DELETED
@@ -1,114 +0,0 @@
1
- # Copyright 2024 Rhymes AI. All rights reserved.
2
- #
3
- # Licensed to the Apache Software Foundation (ASF) under one
4
- # or more contributor license agreements. See the NOTICE file
5
- # distributed with this work for additional information
6
- # regarding copyright ownership. The ASF licenses this file
7
- # to you under the Apache License, Version 2.0 (the
8
- # "License"); you may not use this file except in compliance
9
- # with the License. You may obtain a copy of the License at
10
- #
11
- # http://www.apache.org/licenses/LICENSE-2.0
12
- #
13
- # Unless required by applicable law or agreed to in writing,
14
- # software distributed under the License is distributed on an
15
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16
- # KIND, either express or implied. See the License for the
17
- # specific language governing permissions and limitations
18
- # under the License.
19
-
20
- import logging
21
-
22
- from transformers.configuration_utils import PretrainedConfig
23
-
24
- from .moe_lm import AriaMoELMConfig
25
- from .vision_encoder import AriaVisionConfig
26
-
27
- logger = logging.getLogger(__name__)
28
-
29
-
30
- # adapted from transformers.models.llava.configuration_llava.LlavaConfig
31
- class AriaConfig(PretrainedConfig):
32
- """
33
- Configuration class for Aria model.
34
-
35
- This class handles the configuration for both vision and text components of the Aria model,
36
- as well as additional parameters for image token handling and projector mapping.
37
-
38
- Args:
39
- vision_config (AriaVisionConfig or dict): Configuration for the vision component.
40
- text_config (AriaMoELMConfig or dict): Configuration for the text component.
41
- projector_patch_to_query_dict (dict): Mapping of patch sizes to query dimensions.
42
- ignore_index (int): Index to ignore in loss calculation.
43
- image_token_index (int): Index used to represent image tokens.
44
- **kwargs: Additional keyword arguments passed to the parent class.
45
-
46
- Attributes:
47
- model_type (str): Type of the model, set to "aria".
48
- is_composition (bool): Whether the model is a composition of multiple components.
49
- ignore_index (int): Index to ignore in loss calculation.
50
- image_token_index (int): Index used to represent image tokens.
51
- projector_patch_to_query_dict (dict): Mapping of patch sizes to query dimensions.
52
- vision_config (AriaVisionConfig): Configuration for the vision component.
53
- text_config (AriaMoELMConfig): Configuration for the text component.
54
- """
55
-
56
- model_type = "aria"
57
- is_composition = False
58
-
59
- def __init__(
60
- self,
61
- vision_config=AriaVisionConfig(),
62
- text_config=AriaMoELMConfig(),
63
- projector_patch_to_query_dict={
64
- 1225: 128,
65
- 4900: 256,
66
- },
67
- ignore_index=-100,
68
- image_token_index=32000,
69
- tie_word_embeddings=False,
70
- **kwargs,
71
- ):
72
- super().__init__(**kwargs)
73
- self.ignore_index = ignore_index
74
- self.image_token_index = image_token_index
75
- self.tie_word_embeddings = tie_word_embeddings
76
- attn_implementation = kwargs.pop("attn_implementation", None)
77
-
78
- # Set the default attention implementation to flash_attention_2 if not specified
79
- self._attn_implementation = (
80
- "flash_attention_2" if attn_implementation is None else attn_implementation
81
- )
82
-
83
- # Convert the keys and values of projector_patch_to_query_dict to integers
84
- # This ensures consistency even if they were provided as strings
85
- self.projector_patch_to_query_dict = {
86
- int(k): int(v) for k, v in projector_patch_to_query_dict.items()
87
- }
88
-
89
- if isinstance(vision_config, dict) and "model_type" in vision_config:
90
- vision_config = AriaVisionConfig(**vision_config)
91
- if attn_implementation is None:
92
- vision_attn_implementation = "flash_attention_2"
93
- elif attn_implementation == "sdpa":
94
- logger.warning(
95
- "SDPA is not supported for vit, using flash_attention_2 instead"
96
- )
97
- vision_attn_implementation = "flash_attention_2"
98
- else:
99
- vision_attn_implementation = attn_implementation
100
- vision_config._attn_implementation = vision_attn_implementation
101
-
102
- self.vision_config = vision_config
103
-
104
- if isinstance(text_config, dict) and "model_type" in text_config:
105
- text_attn_implementation = (
106
- "sdpa" if attn_implementation is None else attn_implementation
107
- )
108
- text_config = AriaMoELMConfig(**text_config)
109
- text_config._attn_implementation = text_attn_implementation
110
-
111
- self.text_config = text_config
112
-
113
- # This is needed for the static kv cache
114
- self.num_hidden_layers = self.text_config.num_hidden_layers
modeling_aria.py DELETED
@@ -1,365 +0,0 @@
1
- # Copyright 2024 Rhymes AI. All rights reserved.
2
- #
3
- # Licensed to the Apache Software Foundation (ASF) under one
4
- # or more contributor license agreements. See the NOTICE file
5
- # distributed with this work for additional information
6
- # regarding copyright ownership. The ASF licenses this file
7
- # to you under the Apache License, Version 2.0 (the
8
- # "License"); you may not use this file except in compliance
9
- # with the License. You may obtain a copy of the License at
10
- #
11
- # http://www.apache.org/licenses/LICENSE-2.0
12
- #
13
- # Unless required by applicable law or agreed to in writing,
14
- # software distributed under the License is distributed on an
15
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16
- # KIND, either express or implied. See the License for the
17
- # specific language governing permissions and limitations
18
- # under the License.
19
-
20
- from dataclasses import dataclass
21
- from typing import List, Optional, Tuple, Union
22
-
23
- import torch
24
- import torch.nn as nn
25
- from torch import nn
26
- from transformers import GenerationMixin, PreTrainedModel
27
- from transformers.modeling_outputs import ModelOutput
28
- from transformers.utils import logging
29
-
30
- from .configuration_aria import AriaConfig
31
- from .moe_lm import AriaMoELMForCausalLM
32
- from .projector import AriaProjector
33
- from .vision_encoder import AriaVisionModel
34
-
35
- logger = logging.get_logger(__name__)
36
-
37
-
38
- class AriaPretrainedModel(PreTrainedModel):
39
- """
40
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
41
- """
42
-
43
- config_class = AriaConfig
44
- base_model_prefix = "model"
45
- _no_split_modules = []
46
- supports_gradient_checkpointing = True
47
- _skip_keys_device_placement = "past_key_values"
48
- _supports_flash_attn_2 = True
49
- _supports_cache_class = True
50
- _supports_static_cache = True
51
-
52
- @property
53
- def _supports_sdpa(self):
54
- """
55
- Retrieve language_model's attribute to check whether the model supports
56
- SDPA (Scaled Dot Product Attention) or not.
57
- """
58
- return self.language_model._supports_sdpa
59
-
60
-
61
- @dataclass
62
- # Copied from transformers.models.llava.modeling_llava.LlavaCausalLMOutputWithPast with Llava->Aria
63
- class AriaCausalLMOutputWithPast(ModelOutput):
64
- """
65
- Base class for Aria causal language model (or autoregressive) outputs.
66
-
67
- Args:
68
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
69
- Language modeling loss (for next-token prediction).
70
- logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
71
- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
72
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
73
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
74
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
75
-
76
- Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
77
- `past_key_values` input) to speed up sequential decoding.
78
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
79
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
80
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
81
-
82
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
83
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
84
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
85
- sequence_length)`.
86
-
87
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
88
- heads.
89
- image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
90
- Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
91
- sequence_length, hidden_size)`.
92
-
93
- image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
94
- """
95
-
96
- loss: Optional[torch.FloatTensor] = None
97
- logits: torch.FloatTensor = None
98
- past_key_values: Optional[List[torch.FloatTensor]] = None
99
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
100
- attentions: Optional[Tuple[torch.FloatTensor]] = None
101
- image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
102
-
103
-
104
- def build_mm_projector(config: AriaConfig):
105
- """
106
- Builds and returns an AriaProjector instance based on the provided configuration.
107
-
108
- Args:
109
- config (AriaConfig): The configuration object containing necessary parameters.
110
-
111
- Returns:
112
- AriaProjector: An instance of the AriaProjector class.
113
- """
114
- return AriaProjector(
115
- patch_to_query_dict=config.projector_patch_to_query_dict,
116
- embed_dim=config.vision_config.hidden_size,
117
- num_heads=config.vision_config.num_attention_heads,
118
- kv_dim=config.vision_config.hidden_size,
119
- ff_dim=config.text_config.hidden_size,
120
- output_dim=config.text_config.hidden_size,
121
- )
122
-
123
-
124
- # adapted from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration
125
- class AriaForConditionalGeneration(AriaPretrainedModel, GenerationMixin):
126
- """
127
- Aria model for conditional generation tasks.
128
-
129
- This model combines a vision tower, a multi-modal projector, and a language model
130
- to perform tasks that involve both image and text inputs.
131
- """
132
-
133
- def __init__(self, config: AriaConfig):
134
- super().__init__(config)
135
-
136
- self.vision_tower = AriaVisionModel(config.vision_config)
137
- self.multi_modal_projector = build_mm_projector(config)
138
- self.vocab_size = config.text_config.vocab_size
139
- self.language_model = AriaMoELMForCausalLM(config.text_config)
140
- self.pad_token_id = (
141
- self.config.pad_token_id if self.config.pad_token_id is not None else -1
142
- )
143
- self.post_init()
144
-
145
- def freeze_vit(self):
146
- """Freeze the parameters of the vision tower."""
147
- for param in self.vision_tower.parameters():
148
- param.requires_grad = False
149
-
150
- def freeze_projector(self):
151
- """Freeze the parameters of the multi-modal projector."""
152
- for param in self.multi_modal_projector.parameters():
153
- param.requires_grad = False
154
-
155
- def freeze_llm(self):
156
- """Freeze the parameters of the language model."""
157
- for param in self.language_model.parameters():
158
- param.requires_grad = False
159
-
160
- def get_input_embeddings(self) -> nn.Module:
161
- """Retrieve the input embeddings from the language model."""
162
- return self.language_model.get_input_embeddings()
163
-
164
- def set_input_embeddings(self, value):
165
- """Set the input embeddings for the language model."""
166
- self.language_model.set_input_embeddings(value)
167
-
168
- def get_output_embeddings(self):
169
- """Retrieve the output embeddings from the language model."""
170
- return self.language_model.get_output_embeddings()
171
-
172
- def set_output_embeddings(self, value):
173
- """Set the output embeddings for the language model."""
174
- self.language_model.set_output_embeddings(value)
175
-
176
- def set_moe_z_loss_coeff(self, value):
177
- """
178
- Set the z-loss coefficient for Mixture of Experts (MoE) models.
179
-
180
- Args:
181
- value: The z-loss coefficient value to set.
182
- """
183
- self.language_model.set_z_loss_coeff(value)
184
-
185
- def set_moe_aux_loss_coeff(self, value):
186
- """
187
- Set the auxiliary loss coefficient for Mixture of Experts (MoE) models.
188
-
189
- Args:
190
- value: The auxiliary loss coefficient value to set.
191
- """
192
- self.language_model.set_aux_loss_coeff(value)
193
-
194
- def forward(
195
- self,
196
- input_ids: torch.LongTensor = None,
197
- pixel_values: torch.FloatTensor = None,
198
- pixel_mask: torch.LongTensor = None,
199
- attention_mask: Optional[torch.Tensor] = None,
200
- position_ids: Optional[torch.LongTensor] = None,
201
- past_key_values: Optional[List[torch.FloatTensor]] = None,
202
- inputs_embeds: Optional[torch.FloatTensor] = None,
203
- labels: Optional[torch.LongTensor] = None,
204
- use_cache: Optional[bool] = None,
205
- output_attentions: Optional[bool] = None,
206
- output_hidden_states: Optional[bool] = None,
207
- return_dict: Optional[bool] = None,
208
- cache_position: Optional[torch.LongTensor] = None,
209
- num_logits_to_keep: int = 0,
210
- ) -> Union[Tuple, AriaCausalLMOutputWithPast]:
211
- """
212
- Forward pass of the AriaForConditionalGeneration model.
213
-
214
- This method processes both text and image inputs, merges them if necessary,
215
- and generates output using the language model.
216
-
217
- Args:
218
- input_ids (torch.LongTensor, optional): Input token ids.
219
- pixel_values (torch.FloatTensor, optional): Pixel values of the images.
220
- pixel_mask (torch.LongTensor, optional): Mask for the pixel values.
221
- attention_mask (torch.Tensor, optional): Attention mask.
222
- position_ids (torch.LongTensor, optional): Position ids.
223
- past_key_values (List[torch.FloatTensor], optional): Past key values for efficient processing.
224
- inputs_embeds (torch.FloatTensor, optional): Input embeddings.
225
- labels (torch.LongTensor, optional): Labels for computing the language modeling loss.
226
- use_cache (bool, optional): Whether to use the model's cache mechanism.
227
- output_attentions (bool, optional): Whether to output attention weights.
228
- output_hidden_states (bool, optional): Whether to output hidden states.
229
- return_dict (bool, optional): Whether to return a ModelOutput object.
230
-
231
- Returns:
232
- Union[Tuple, AriaCausalLMOutputWithPast]: Model outputs.
233
- """
234
- output_attentions = (
235
- output_attentions
236
- if output_attentions is not None
237
- else self.config.output_attentions
238
- )
239
- output_hidden_states = (
240
- output_hidden_states
241
- if output_hidden_states is not None
242
- else self.config.output_hidden_states
243
- )
244
- return_dict = (
245
- return_dict if return_dict is not None else self.config.use_return_dict
246
- )
247
-
248
- if inputs_embeds is None:
249
- # 1. Extract the input embeddings
250
- inputs_embeds = self.get_input_embeddings()(input_ids)
251
-
252
- image_features = None
253
- if pixel_values is not None:
254
- image_outputs, image_attn_mask = self.vision_tower(
255
- pixel_values,
256
- pixel_mask=pixel_mask,
257
- )
258
-
259
- selected_image_feature = image_outputs.last_hidden_state
260
- image_features = self.multi_modal_projector(
261
- selected_image_feature, attn_mask=image_attn_mask
262
- )
263
-
264
- if image_features is not None:
265
- n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
266
- n_image_features = image_features.shape[0] * image_features.shape[1]
267
-
268
- if n_image_tokens != n_image_features:
269
- raise ValueError(
270
- f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
271
- )
272
- special_image_mask = (
273
- (input_ids == self.config.image_token_index)
274
- .unsqueeze(-1)
275
- .expand_as(inputs_embeds)
276
- .to(inputs_embeds.device)
277
- )
278
- image_features = image_features.to(
279
- inputs_embeds.device, inputs_embeds.dtype
280
- )
281
- inputs_embeds = inputs_embeds.masked_scatter(
282
- special_image_mask, image_features
283
- )
284
-
285
- outputs = self.language_model(
286
- attention_mask=attention_mask,
287
- position_ids=position_ids,
288
- past_key_values=past_key_values,
289
- inputs_embeds=inputs_embeds,
290
- use_cache=use_cache,
291
- output_attentions=output_attentions,
292
- output_hidden_states=output_hidden_states,
293
- return_dict=return_dict,
294
- cache_position=cache_position,
295
- num_logits_to_keep=num_logits_to_keep,
296
- )
297
-
298
- logits = outputs[0]
299
-
300
- loss = None
301
- if labels is not None:
302
- # Shift so that tokens < n predict n
303
- if attention_mask is not None:
304
- # we use the input attention mask to shift the logits and labels, because it is 2D.
305
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
306
- shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
307
- logits.device
308
- )
309
- shift_logits = logits[..., :-1, :][
310
- shift_attention_mask.to(logits.device) != 0
311
- ].contiguous()
312
- shift_labels = labels[..., 1:][
313
- shift_attention_mask.to(labels.device) != 0
314
- ].contiguous()
315
- else:
316
- shift_logits = logits[..., :-1, :].contiguous()
317
- shift_labels = labels[..., 1:].contiguous()
318
- # Flatten the tokens
319
- loss_fct = nn.CrossEntropyLoss()
320
- loss = loss_fct(
321
- shift_logits.view(-1, shift_logits.size(-1)),
322
- shift_labels.view(-1).to(shift_logits.device),
323
- )
324
-
325
- if not return_dict:
326
- output = (logits,) + outputs[1:]
327
- return (loss,) + output if loss is not None else output
328
-
329
- return AriaCausalLMOutputWithPast(
330
- loss=loss,
331
- logits=logits,
332
- past_key_values=outputs.past_key_values,
333
- hidden_states=outputs.hidden_states,
334
- attentions=outputs.attentions,
335
- )
336
-
337
- def prepare_inputs_for_generation(
338
- self,
339
- input_ids,
340
- past_key_values=None,
341
- inputs_embeds=None,
342
- pixel_values=None,
343
- pixel_mask=None,
344
- attention_mask=None,
345
- cache_position=None,
346
- num_logits_to_keep=None,
347
- **kwargs,
348
- ):
349
- model_inputs = self.language_model.prepare_inputs_for_generation(
350
- input_ids,
351
- past_key_values=past_key_values,
352
- inputs_embeds=inputs_embeds,
353
- attention_mask=attention_mask,
354
- cache_position=cache_position,
355
- num_logits_to_keep=num_logits_to_keep,
356
- **kwargs,
357
- )
358
-
359
- if cache_position[0] == 0:
360
- # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
361
- # Otherwise we need pixel values to be passed to model
362
- model_inputs["pixel_values"] = pixel_values
363
- model_inputs["pixel_mask"] = pixel_mask
364
-
365
- return model_inputs
moe_lm.py DELETED
@@ -1,679 +0,0 @@
1
- # Copyright 2024 Rhymes AI. All rights reserved.
2
- #
3
- # Licensed to the Apache Software Foundation (ASF) under one
4
- # or more contributor license agreements. See the NOTICE file
5
- # distributed with this work for additional information
6
- # regarding copyright ownership. The ASF licenses this file
7
- # to you under the Apache License, Version 2.0 (the
8
- # "License"); you may not use this file except in compliance
9
- # with the License. You may obtain a copy of the License at
10
- #
11
- # http://www.apache.org/licenses/LICENSE-2.0
12
- #
13
- # Unless required by applicable law or agreed to in writing,
14
- # software distributed under the License is distributed on an
15
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16
- # KIND, either express or implied. See the License for the
17
- # specific language governing permissions and limitations
18
- # under the License.
19
-
20
- import logging
21
- import os
22
- from typing import Tuple
23
-
24
- import torch
25
- import torch.nn as nn
26
- import torch.nn.functional as F
27
- from torch import nn
28
- from transformers import GenerationMixin, LlamaConfig
29
- from transformers.models.llama.modeling_llama import (
30
- ACT2FN,
31
- LLAMA_ATTENTION_CLASSES,
32
- LlamaDecoderLayer,
33
- LlamaForCausalLM,
34
- LlamaMLP,
35
- LlamaModel,
36
- LlamaRMSNorm,
37
- LlamaRotaryEmbedding,
38
- )
39
-
40
- logger = logging.getLogger(__name__)
41
-
42
-
43
- class AriaMoELMConfig(LlamaConfig):
44
- """
45
- Configuration class for AriaMoE language model.
46
-
47
- This class extends the LlamaConfig to include additional parameters specific to the Mixture of Experts (MoE) architecture.
48
- """
49
-
50
- model_type = "aria_moe_lm"
51
-
52
- def __init__(
53
- self,
54
- moe_intermediate_size: int = 4096,
55
- moe_num_experts: int = 8,
56
- moe_topk: int = 2,
57
- moe_z_loss_coeff: float = 1e-5,
58
- moe_aux_loss_coeff: float = 1e-3,
59
- moe_num_shared_experts: int = 2,
60
- **kwargs,
61
- ):
62
- """
63
- Initialize the AriaMoELMConfig.
64
-
65
- Args:
66
- moe_intermediate_size (int): The intermediate size for MoE layers. Default is 4096.
67
- moe_num_experts (int): The number of experts in the MoE layer. Default is 8.
68
- moe_topk (int): The number of top experts to route to for each token. Default is 2.
69
- moe_z_loss_coeff (float): The coefficient for the auxiliary z-loss. Default is 1e-5.
70
- moe_aux_loss_coeff (float): The coefficient for the auxiliary load balancing loss. Default is 1e-3.
71
- moe_num_shared_experts (int): The number of shared experts. Default is 2.
72
- **kwargs: Additional keyword arguments to be passed to the parent LlamaConfig.
73
- """
74
- super().__init__(**kwargs)
75
- self.moe_intermediate_size = moe_intermediate_size
76
- self.moe_num_experts = moe_num_experts
77
- self.moe_topk = moe_topk
78
- self.moe_z_loss_coeff = moe_z_loss_coeff
79
- self.moe_aux_loss_coeff = moe_aux_loss_coeff
80
- self.moe_num_shared_experts = moe_num_shared_experts
81
-
82
-
83
- # copied from https://github.com/NVIDIA/Megatron-LM/blob/54f1f78529cbc2b9cddad313e7f9d96ac0420a27/megatron/core/transformer/moe/moe_utils.py#L101-L142
84
- class MoEAuxLossAutoScaler(torch.autograd.Function):
85
- """An AutoScaler that compute and scales the grad for auxiliary loss."""
86
-
87
- main_loss_backward_scale: torch.Tensor = torch.tensor(1.0)
88
-
89
- @staticmethod
90
- def forward(ctx, output: torch.Tensor, aux_loss: torch.Tensor):
91
- """Preserve the aux_loss by storing it in the context to avoid garbage collection.
92
-
93
- Args:
94
- output (torch.Tensor): The output tensor.
95
- aux_loss (torch.Tensor): The auxiliary loss tensor.
96
-
97
- Returns:
98
- torch.Tensor: The output tensor.
99
- """
100
- ctx.save_for_backward(aux_loss)
101
- return output
102
-
103
- @staticmethod
104
- def backward(ctx, grad_output: torch.Tensor):
105
- """Compute and scale the gradient for auxiliary loss..
106
-
107
- Args:
108
- grad_output (torch.Tensor): The gradient of the output.
109
-
110
- Returns:
111
- Tuple[torch.Tensor, torch.Tensor]: The gradient of the output, scaled auxiliary loss gradient.
112
- """
113
- (aux_loss,) = ctx.saved_tensors
114
- aux_loss_backward_scale = MoEAuxLossAutoScaler.main_loss_backward_scale
115
- scaled_aux_loss_grad = torch.ones_like(aux_loss) * aux_loss_backward_scale
116
- return grad_output, scaled_aux_loss_grad
117
-
118
- @staticmethod
119
- def set_loss_scale(scale: torch.Tensor):
120
- """set the scale of the aux loss.
121
-
122
- Args:
123
- scale (torch.Tensor): The scale value to set. Please ensure that the scale passed in matches the scale of the main_loss.
124
- """
125
- MoEAuxLossAutoScaler.main_loss_backward_scale = scale
126
-
127
-
128
- def z_loss_func(logits, z_loss_coeff):
129
- """Encourages the router's logits to remain small to enhance stability.
130
- Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.
131
-
132
- Args:
133
- logits (torch.Tensor): The logits of the router.
134
-
135
- Returns:
136
- torch.Tensor: The logits after applying the z-loss.
137
- """
138
-
139
- z_loss = torch.mean(torch.square(torch.logsumexp(logits, dim=-1))) * z_loss_coeff
140
- return z_loss
141
-
142
-
143
- def switch_load_balancing_loss_func(
144
- probs: torch.Tensor,
145
- tokens_per_expert: torch.Tensor,
146
- topk: int,
147
- moe_aux_loss_coeff: float,
148
- ):
149
- """Calculate the auxiliary loss for better load balancing.
150
- Please refer to the Switch Transformer paper (https://arxiv.org/abs/2101.03961) for details.
151
-
152
- Args:
153
- probs (torch.Tensor): The softmax probs output by the router for each token. [num_tokens, num_experts]
154
- tokens_per_expert (torch.Tensor): The number of assigned tokens for each expert. [num_experts]
155
-
156
- Returns:
157
- torch.Tensor: The auxiliary loss for load balancing.
158
- """
159
- num_tokens = probs.shape[0] * topk
160
- num_experts = probs.shape[1]
161
-
162
- probs_mean_per_expert = probs.mean(dim=0)
163
- aux_loss = torch.sum(probs_mean_per_expert * tokens_per_expert) * (
164
- num_experts / num_tokens * moe_aux_loss_coeff
165
- )
166
- return aux_loss
167
-
168
-
169
- # adapted from https://github.com/NVIDIA/Megatron-LM/blob/54f1f78529cbc2b9cddad313e7f9d96ac0420a27/megatron/core/transformer/moe/router.py#L96-L304
170
- class TopKRouter(nn.Module):
171
- """
172
- Top-K Router for Mixture of Experts (MoE) models.
173
-
174
- This router determines which experts should process each token based on the top-k scoring experts.
175
- It also applies auxiliary losses to encourage load balancing among experts.
176
-
177
- Args:
178
- config (AriaMoELMConfig): Configuration object containing MoE-related parameters.
179
- """
180
-
181
- def __init__(self, config: AriaMoELMConfig):
182
- super().__init__()
183
- self.config = config
184
-
185
- self.weight = nn.Parameter(
186
- torch.empty((self.config.moe_num_experts, self.config.hidden_size))
187
- )
188
- # FIXME: initialize the weight
189
-
190
- def gating(self, input: torch.Tensor) -> torch.Tensor:
191
- """
192
- Compute the gating logits for each token-expert pair.
193
-
194
- Args:
195
- input (torch.Tensor): Input tensor of shape [batch_size * seq_len, hidden_size].
196
-
197
- Returns:
198
- torch.Tensor: Logits tensor of shape [batch_size * seq_len, num_experts].
199
- """
200
- logits = torch.nn.functional.linear(input, self.weight)
201
- return logits
202
-
203
- def apply_z_loss(self, logits: torch.Tensor) -> torch.Tensor:
204
- """
205
- Apply z-loss to encourage router logits to remain small for enhanced stability.
206
-
207
- Args:
208
- logits (torch.Tensor): Router logits.
209
-
210
- Returns:
211
- torch.Tensor: Logits with z-loss applied.
212
- """
213
- z_loss = z_loss_func(logits, self.config.moe_z_loss_coeff)
214
- logits = MoEAuxLossAutoScaler.apply(logits, z_loss)
215
- return logits
216
-
217
- def apply_aux_loss(
218
- self,
219
- logits: torch.Tensor,
220
- tokens_per_expert: torch.Tensor,
221
- activation: torch.Tensor,
222
- ) -> torch.Tensor:
223
- """
224
- Apply auxiliary loss for load balancing among experts.
225
-
226
- Args:
227
- logits (torch.Tensor): Router logits.
228
- tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.
229
- activation (torch.Tensor): Activation values.
230
-
231
- Returns:
232
- torch.Tensor: Activation with auxiliary loss applied.
233
- """
234
- probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
235
- aux_loss = switch_load_balancing_loss_func(
236
- probs,
237
- tokens_per_expert,
238
- self.config.moe_topk,
239
- self.config.moe_aux_loss_coeff,
240
- )
241
- return MoEAuxLossAutoScaler.apply(activation, aux_loss)
242
-
243
- def routing(
244
- self, logits: torch.Tensor
245
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
246
- """
247
- Perform the routing operation to determine expert assignments.
248
-
249
- Args:
250
- logits (torch.Tensor): Router logits.
251
-
252
- Returns:
253
- Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
254
- - scores: Softmax probabilities for top-k experts.
255
- - top_indices: Indices of top-k experts for each token.
256
- - tokens_per_expert: Number of tokens assigned to each expert.
257
- """
258
- if self.training:
259
- logits = self.apply_z_loss(logits)
260
-
261
- top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1)
262
- scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32).type_as(logits)
263
-
264
- tokens_per_expert = torch.histc(
265
- top_indices.flatten(),
266
- bins=self.config.moe_num_experts,
267
- min=0,
268
- max=self.config.moe_num_experts - 1,
269
- )
270
-
271
- if self.training:
272
- scores = self.apply_aux_loss(logits, tokens_per_expert, scores)
273
- return scores, top_indices, tokens_per_expert
274
-
275
- def forward(
276
- self, input: torch.Tensor
277
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
278
- """
279
- Forward pass of the TopKRouter.
280
-
281
- Args:
282
- input (torch.Tensor): Input tensor of shape [batch_size * seq_len, hidden_size].
283
-
284
- Returns:
285
- Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
286
- - scores: Softmax probabilities for top-k experts.
287
- - top_indices: Indices of top-k experts for each token.
288
- - tokens_per_expert: Number of tokens assigned to each expert.
289
- """
290
- logits = self.gating(input)
291
- logits = logits.view(-1, self.config.moe_num_experts)
292
- scores, top_indices, tokens_per_expert = self.routing(logits)
293
- return scores, top_indices, tokens_per_expert
294
-
295
-
296
- # adapted from https://github.com/NVIDIA/Megatron-LM/blob/54f1f78529cbc2b9cddad313e7f9d96ac0420a27/megatron/core/transformer/moe/token_dispatcher.py#L291-L587
297
- class TokenDispatcher:
298
- """
299
- Handles the dispatching and gathering of tokens to and from experts.
300
-
301
- This class is responsible for permuting tokens based on expert assignments and
302
- unpermuting them after expert processing.
303
-
304
- Args:
305
- config (AriaMoELMConfig): Configuration object containing MoE-related parameters.
306
- """
307
-
308
- def __init__(self, config: AriaMoELMConfig):
309
- self.config = config
310
- self.hidden_states_shape = None
311
- self.reversed_input_permutation_mapping = None
312
-
313
- def token_permutation(
314
- self, hidden_states: torch.Tensor, indices: torch.Tensor
315
- ) -> torch.Tensor:
316
- """
317
- Permute tokens based on expert assignments.
318
-
319
- Args:
320
- hidden_states (torch.Tensor): Input hidden states.
321
- indices (torch.Tensor): Expert assignment indices.
322
-
323
- Returns:
324
- torch.Tensor: Permuted tokens.
325
- """
326
- self.hidden_states_shape = hidden_states.shape
327
- hidden_states = hidden_states.view(-1, hidden_states.size(-1))
328
- flatten_indices = indices.flatten()
329
- sorted_indices = torch.argsort(flatten_indices, stable=True)
330
- permuted_tokens = hidden_states.index_select(
331
- 0, sorted_indices // self.config.moe_topk
332
- )
333
- self.reversed_input_permutation_mapping = sorted_indices
334
- return permuted_tokens
335
-
336
- def token_unpermutation(
337
- self, permuted_tokens: torch.Tensor, scores: torch.Tensor
338
- ) -> torch.Tensor:
339
- """
340
- Unpermute tokens and combine expert outputs.
341
-
342
- Args:
343
- permuted_tokens (torch.Tensor): Tokens after expert processing.
344
- scores (torch.Tensor): Expert assignment scores.
345
-
346
- Returns:
347
- torch.Tensor: Unpermuted and combined output.
348
- """
349
- num_unpermuted_tokens = scores.numel()
350
- unpermuted_tokens = torch.zeros(
351
- (num_unpermuted_tokens, permuted_tokens.size(1)),
352
- dtype=permuted_tokens.dtype,
353
- device=permuted_tokens.device,
354
- )
355
- unpermuted_tokens.index_copy_(
356
- 0, self.reversed_input_permutation_mapping, permuted_tokens
357
- )
358
- unpermuted_tokens = unpermuted_tokens.reshape(
359
- -1, self.config.moe_topk, permuted_tokens.size(1)
360
- )
361
-
362
- unpermuted_tokens = unpermuted_tokens * scores.unsqueeze(-1)
363
- unpermuted_tokens = unpermuted_tokens.sum(dim=1).type_as(permuted_tokens)
364
- output = unpermuted_tokens.view(self.hidden_states_shape)
365
- return output
366
-
367
-
368
- class SharedExpertMLP(LlamaMLP):
369
- """
370
- Shared Expert MLP for shared experts.
371
-
372
- Unlike routed experts, shared experts process all tokens without routing.
373
- This class reconfigures the intermediate size in comparison to the LlamaMLP.
374
-
375
- Args:
376
- config (AriaMoELMConfig): Configuration object for the AriaMoE language model.
377
- """
378
-
379
- def __init__(self, config: AriaMoELMConfig):
380
- nn.Module.__init__(self)
381
- self.config = config
382
- self.hidden_size = config.hidden_size
383
- self.intermediate_size = (
384
- config.moe_intermediate_size * config.moe_num_shared_experts
385
- )
386
- self.gate_proj = nn.Linear(
387
- self.hidden_size, self.intermediate_size, bias=config.mlp_bias
388
- )
389
- self.up_proj = nn.Linear(
390
- self.hidden_size, self.intermediate_size, bias=config.mlp_bias
391
- )
392
- self.down_proj = nn.Linear(
393
- self.intermediate_size, self.hidden_size, bias=config.mlp_bias
394
- )
395
- self.act_fn = ACT2FN[config.hidden_act]
396
-
397
-
398
- def sequential_gemm(input, weight, tokens_per_expert):
399
- """
400
- Compute the matrix multiplication (GEMM) for each expert sequentially. This approach is computationally inefficient, especially when dealing with a large number of experts.
401
-
402
- Args:
403
- input (torch.Tensor): Input tensor of shape (num_tokens, in_features).
404
- weight (torch.Tensor): Weight tensor of shape (num_experts, in_features, out_features).
405
- tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.
406
-
407
- Returns:
408
- torch.Tensor: Output tensor of shape (num_tokens, out_features).
409
- """
410
- num_tokens = input.shape[0]
411
- out_features = weight.shape[-1]
412
- output = torch.zeros(
413
- num_tokens, out_features, dtype=input.dtype, device=input.device
414
- )
415
-
416
- cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
417
- # Insert zero at the beginning for offset index's convenience
418
- zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device)
419
- cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))
420
-
421
- for expert_num in range(weight.shape[0]):
422
- start = cumsum_num_tokens[expert_num]
423
- end = cumsum_num_tokens[expert_num + 1]
424
- tokens = input[start:end]
425
-
426
- out = torch.matmul(tokens, weight[expert_num])
427
- output[start:end] = out
428
- return output
429
-
430
-
431
- try:
432
- from grouped_gemm.ops import gmm as experts_gemm
433
-
434
- if os.environ.get("USE_GROUPED_GEMM", "1") == "0":
435
- logger.warning(
436
- "environment variable USE_GROUPED_GEMM is set to 0, using sequential GEMM instead."
437
- )
438
- experts_gemm = sequential_gemm
439
- except ImportError:
440
- logger.warning(
441
- "`grouped_gemm` is not installed, using sequential GEMM, which is slower."
442
- )
443
- experts_gemm = sequential_gemm
444
-
445
-
446
- class GroupedGEMM(nn.Module):
447
- """
448
- Grouped GEMM (General Matrix Multiplication) module for efficient expert computation.
449
- This module utilizes the grouped_gemm library (https://github.com/fanshiqing/grouped_gemm)
450
- for optimized performance. If the grouped_gemm library is not installed, it gracefully
451
- falls back to a sequential GEMM implementation, which may be slower but ensures
452
- functionality.
453
-
454
- Args:
455
- in_features (int): Number of input features.
456
- out_features (int): Number of output features.
457
- groups (int): Number of expert groups.
458
- """
459
-
460
- def __init__(self, in_features, out_features, groups):
461
- super().__init__()
462
- self.in_features = in_features
463
- self.out_features = out_features
464
- self.groups = groups
465
- self.weight = nn.Parameter(torch.empty(groups, in_features, out_features))
466
-
467
- def forward(self, input, tokens_per_expert):
468
- """
469
- Perform grouped matrix multiplication.
470
-
471
- Args:
472
- input (torch.Tensor): Input tensor of shape (num_tokens, in_features).
473
- tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.
474
-
475
- Returns:
476
- torch.Tensor: Output tensor of shape (num_tokens, out_features).
477
- """
478
- tokens_per_expert = tokens_per_expert.cpu()
479
-
480
- # Ensure the CUDA device matches the input tensor's device.
481
- # This mismatch can occur when using `transformers.AutoModel.from_pretrained`
482
- # with `device_map="auto"` on a multi-GPU setup.
483
- torch.cuda.set_device(input.device)
484
- return experts_gemm(input, self.weight, tokens_per_expert)
485
-
486
-
487
- class GroupedMLP(nn.Module):
488
- """
489
- Grouped MLP module for Mixture of Experts.
490
-
491
- Args:
492
- config (AriaMoELMConfig): Configuration object for the model.
493
- """
494
-
495
- def __init__(self, config: AriaMoELMConfig) -> None:
496
- super().__init__()
497
- self.config = config
498
- self.fc1 = GroupedGEMM(
499
- config.hidden_size, config.moe_intermediate_size * 2, config.moe_num_experts
500
- )
501
- self.fc2 = GroupedGEMM(
502
- config.moe_intermediate_size, config.hidden_size, config.moe_num_experts
503
- )
504
-
505
- def glu(x):
506
- x = torch.chunk(x, 2, dim=-1)
507
- return F.silu(x[0]) * x[1]
508
-
509
- self.activation_func = glu
510
-
511
- def forward(self, permuted_tokens, tokens_per_expert):
512
- """
513
- Forward pass of the Grouped MLP.
514
-
515
- Args:
516
- permuted_tokens (torch.Tensor): Permuted input tokens.
517
- tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.
518
-
519
- Returns:
520
- torch.Tensor: Output tensor after passing through the MLP.
521
- """
522
- fc1_output = self.fc1(permuted_tokens, tokens_per_expert)
523
- fc1_output = self.activation_func(fc1_output)
524
- fc2_output = self.fc2(fc1_output, tokens_per_expert)
525
- return fc2_output
526
-
527
-
528
- class MoELayer(nn.Module):
529
- """
530
- Mixture of Experts (MoE) Layer for the AriaMoE model.
531
-
532
- This layer implements the MoE mechanism, which routes input tokens to different experts
533
- based on a routing algorithm, processes them through the experts, and then combines
534
- the outputs.
535
-
536
- Args:
537
- config (AriaMoELMConfig): Configuration object for the MoE layer.
538
- """
539
-
540
- def __init__(self, config: AriaMoELMConfig):
541
- super().__init__()
542
-
543
- self.router = TopKRouter(config)
544
- self.token_dispatcher = TokenDispatcher(config)
545
- self.experts = GroupedMLP(config)
546
- self.shared_experts = SharedExpertMLP(config)
547
-
548
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
549
- """
550
- Forward pass of the MoE Layer.
551
-
552
- Args:
553
- hidden_states (torch.Tensor): Input tensor of shape (batch_size, sequence_length, hidden_size).
554
-
555
- Returns:
556
- torch.Tensor: Output tensor after passing through the MoE layer.
557
-
558
- Process:
559
- 1. Route tokens to experts using the router.
560
- 2. Permute tokens based on routing decisions.
561
- 3. Process tokens through experts.
562
- 4. Unpermute and combine expert outputs.
563
- 5. Add shared expert output to the final result.
564
- """
565
- scores, indices, tokens_per_expert = self.router(hidden_states)
566
-
567
- permuted_tokens = self.token_dispatcher.token_permutation(
568
- hidden_states, indices
569
- )
570
-
571
- expert_output = self.experts(permuted_tokens, tokens_per_expert)
572
-
573
- output = self.token_dispatcher.token_unpermutation(expert_output, scores)
574
-
575
- shared_expert_output = self.shared_experts(hidden_states)
576
- output += shared_expert_output
577
- return output
578
-
579
-
580
- class MoEDecoderLayer(LlamaDecoderLayer):
581
- """
582
- Custom Decoder Layer for the AriaMoE model which modifies the standard `LlamaDecoderLayer` by
583
- replacing the traditional MLP with a Mixture of Experts (MoE) Layer.
584
-
585
- Args:
586
- config (LlamaConfig): Configuration object for the layer.
587
- layer_idx (int): Index of the current layer in the model.
588
- """
589
-
590
- def __init__(self, config: LlamaConfig, layer_idx: int):
591
- nn.Module.__init__(self)
592
- self.hidden_size = config.hidden_size
593
-
594
- self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](
595
- config=config, layer_idx=layer_idx
596
- )
597
-
598
- self.mlp = MoELayer(config)
599
- self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
600
- self.post_attention_layernorm = LlamaRMSNorm(
601
- config.hidden_size, eps=config.rms_norm_eps
602
- )
603
-
604
-
605
- class AriaMoELMModel(LlamaModel):
606
- """
607
- Custom LlamaModel for the AriaMoE model which modifies the standard LlamaModel by
608
- replacing the `LlamaDecoderLayer` with `MoEDecoderLayer`.
609
-
610
- This model implements a Mixture of Experts (MoE) approach, where each layer contains
611
- multiple expert networks that specialize in different aspects of the input.
612
-
613
- Args:
614
- config (LlamaConfig): Configuration object for the model.
615
- """
616
-
617
- def __init__(self, config: LlamaConfig):
618
- super().__init__(config)
619
- self.padding_idx = config.pad_token_id
620
- self.vocab_size = config.vocab_size
621
-
622
- self.embed_tokens = nn.Embedding(
623
- config.vocab_size, config.hidden_size, self.padding_idx
624
- )
625
- self.layers = nn.ModuleList(
626
- [
627
- MoEDecoderLayer(config, layer_idx)
628
- for layer_idx in range(config.num_hidden_layers)
629
- ]
630
- )
631
- self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
632
- self.rotary_emb = LlamaRotaryEmbedding(config=config)
633
- self.gradient_checkpointing = False
634
-
635
- # Initialize weights and apply final processing
636
- self.post_init()
637
-
638
-
639
- class AriaMoELMForCausalLM(LlamaForCausalLM, GenerationMixin):
640
- """
641
- AriaMoE model for causal language modeling tasks.
642
-
643
- This class extends LlamaForCausalLM to incorporate the Mixture of Experts (MoE) approach,
644
- allowing for more efficient and scalable language modeling.
645
-
646
- Args:
647
- config (AriaMoELMConfig): Configuration object for the model.
648
- """
649
-
650
- _tied_weights_keys = ["lm_head.weight"]
651
- config_class = AriaMoELMConfig
652
- _no_split_modules = ["MoEDecoderLayer"]
653
-
654
- def __init__(self, config):
655
- super().__init__(config)
656
- self.model = AriaMoELMModel(config)
657
- self.vocab_size = config.vocab_size
658
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
659
-
660
- # Initialize weights and apply final processing
661
- self.post_init()
662
-
663
- def set_z_loss_coeff(self, z_loss_coeff: float):
664
- """
665
- Set the coefficient for the z-loss in the MoE routing.
666
-
667
- Args:
668
- z_loss_coeff (float): The coefficient for the z-loss.
669
- """
670
- self.config.moe_z_loss_coeff = z_loss_coeff
671
-
672
- def set_aux_loss_coeff(self, aux_loss_coeff: float):
673
- """
674
- Set the coefficient for the auxiliary loss in the MoE routing.
675
-
676
- Args:
677
- aux_loss_coeff (float): The coefficient for the auxiliary loss.
678
- """
679
- self.config.moe_aux_loss_coeff = aux_loss_coeff
processing_aria.py DELETED
@@ -1,305 +0,0 @@
1
- # Copyright 2024 Rhymes AI. All rights reserved.
2
- #
3
- # Licensed to the Apache Software Foundation (ASF) under one
4
- # or more contributor license agreements. See the NOTICE file
5
- # distributed with this work for additional information
6
- # regarding copyright ownership. The ASF licenses this file
7
- # to you under the Apache License, Version 2.0 (the
8
- # "License"); you may not use this file except in compliance
9
- # with the License. You may obtain a copy of the License at
10
- #
11
- # http://www.apache.org/licenses/LICENSE-2.0
12
- #
13
- # Unless required by applicable law or agreed to in writing,
14
- # software distributed under the License is distributed on an
15
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16
- # KIND, either express or implied. See the License for the
17
- # specific language governing permissions and limitations
18
- # under the License.
19
-
20
- import inspect
21
- import logging
22
- import re
23
- from typing import List, Optional, Union
24
-
25
- from transformers import AutoTokenizer, BatchFeature
26
- from transformers.image_utils import ImageInput
27
- from transformers.processing_utils import ProcessorMixin
28
- from transformers.tokenization_utils import (
29
- PaddingStrategy,
30
- PreTokenizedInput,
31
- TensorType,
32
- TextInput,
33
- TruncationStrategy,
34
- )
35
-
36
- from .vision_processor import AriaVisionProcessor
37
-
38
- logger = logging.getLogger(__name__)
39
-
40
-
41
- class AriaProcessor(ProcessorMixin):
42
- """
43
- AriaProcessor is a processor for the Aria model which wraps the Aria image preprocessor and the LLama slow tokenizer.
44
- Args:
45
- image_processor(AriaVisionProcessor): The AriaVisionProcessor to use for image preprocessing.
46
- tokenizer(AutoTokenizer): The AutoTokenizer to use for tokenizing the text.
47
- patch_size(int): The patch size to use for the image processor.
48
- chat_template(str): The chat template to use for the tokenizer.
49
- image_token(str): The image token to use for the tokenizer.
50
- """
51
-
52
- attributes = []
53
- valid_kwargs = ["chat_template", "patch_size", "image_token"]
54
- image_processor_class = None
55
- tokenizer_class = "AutoTokenizer"
56
-
57
- def __init__(
58
- self,
59
- image_processor: AriaVisionProcessor = None,
60
- tokenizer: Union[AutoTokenizer, str] = None,
61
- patch_size: int = 490,
62
- chat_template: str = None,
63
- image_token: str = "<|img|>",
64
- ):
65
- super().__init__(chat_template=chat_template)
66
-
67
- if image_processor is None:
68
- self.image_processor = AriaVisionProcessor(max_image_size=patch_size)
69
- else:
70
- self.image_processor = image_processor
71
-
72
- if isinstance(tokenizer, str):
73
- self.tokenizer = AutoTokenizer.from_pretrained(
74
- tokenizer, trust_remote_code=True, use_fast=False
75
- )
76
- else:
77
- self.tokenizer = tokenizer
78
-
79
- if self.tokenizer is not None and self.tokenizer.pad_token is None:
80
- self.tokenizer.pad_token = self.tokenizer.unk_token
81
-
82
- self.image_token = image_token
83
-
84
- # Copied from transformers.models.llava_next.processing_llava_next.LlavaNextProcessor.__call__
85
- def __call__(
86
- self,
87
- text: Union[
88
- TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
89
- ],
90
- images: ImageInput = None,
91
- padding: Union[bool, str, PaddingStrategy] = False,
92
- truncation: Union[bool, str, TruncationStrategy] = None,
93
- max_length: Optional[int] = None,
94
- max_image_size: Optional[int] = 980,
95
- split_image: Optional[bool] = False,
96
- return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
97
- return_final_prompts: Optional[bool] = False,
98
- ) -> BatchFeature:
99
- """
100
- Main method to prepare for the model one or several sequence(s) and image(s). Please refer to the docstring
101
- of the above two methods for more information.
102
-
103
- Args:
104
- text (`str`, `List[str]`, `List[List[str]]`):
105
- The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
106
- (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
107
- `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
108
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
109
- The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
110
- tensor. Both channels-first and channels-last formats are supported.
111
- padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
112
- Select a strategy to pad the returned sequences (according to the model's padding side and padding
113
- index) among:
114
- - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
115
- sequence if provided).
116
- - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
117
- acceptable input length for the model if that argument is not provided.
118
- - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
119
- lengths).
120
- max_length (`int`, *optional*):
121
- Maximum length of the returned list and optionally padding length (see above).
122
- max_image_size (`int`, *optional*):
123
- Maximum size of the image to be processed.
124
- split_image (`bool`, *optional*):
125
- Whether to split the image into patches before processing.
126
- truncation (`bool`, *optional*):
127
- Activates truncation to cut input sequences longer than `max_length` to `max_length`.
128
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
129
- If set, will return tensors of a particular framework. Acceptable values are:
130
-
131
- - `'tf'`: Return TensorFlow `tf.constant` objects.
132
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
133
- - `'np'`: Return NumPy `np.ndarray` objects.
134
- - `'jax'`: Return JAX `jnp.ndarray` objects.
135
-
136
- Returns:
137
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
138
-
139
- - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
140
- - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
141
- `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
142
- `None`).
143
- - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
144
- - **pixel_mask** -- Pixel mask to be fed to a model. Returned when `images` is not `None`.
145
- """
146
- if isinstance(text, str):
147
- text = [text]
148
- elif not isinstance(text, list) and not isinstance(text[0], str):
149
- raise ValueError(
150
- "Invalid input text. Please provide a string, or a list of strings"
151
- )
152
-
153
- if images is not None:
154
- image_inputs = self.image_processor(
155
- images,
156
- return_tensors=return_tensors,
157
- max_image_size=max_image_size,
158
- split_image=split_image,
159
- )
160
- # expand the image_token according to the num_crops of image
161
- prompt_strings = []
162
- crop_iter = iter(image_inputs.pop("num_crops"))
163
- for prompt in text:
164
- prompt_strings.append(
165
- re.sub(
166
- re.escape(self.image_token),
167
- lambda _: next(crop_iter) * self.image_token,
168
- prompt,
169
- )
170
- )
171
-
172
- max_image_size = (
173
- max_image_size
174
- if max_image_size is not None
175
- else self.image_processor.max_image_size
176
- )
177
- if max_image_size == 490:
178
- num_image_tokens = 128
179
- elif max_image_size == 980:
180
- num_image_tokens = 256
181
- else:
182
- raise ValueError(
183
- f"max_image_size must be either 490 or 980, got {max_image_size}"
184
- )
185
- prompt_strings = [
186
- sample.replace(self.image_token, self.image_token * num_image_tokens)
187
- for sample in prompt_strings
188
- ]
189
-
190
- else:
191
- image_inputs = {}
192
- prompt_strings = text
193
-
194
- text_inputs = self.tokenizer(
195
- prompt_strings,
196
- return_tensors=return_tensors,
197
- padding=padding,
198
- truncation=truncation,
199
- max_length=max_length,
200
- )
201
-
202
- if return_final_prompts:
203
- return BatchFeature(data={**text_inputs, **image_inputs}), prompt_strings
204
- else:
205
- return BatchFeature(data={**text_inputs, **image_inputs})
206
-
207
- @staticmethod
208
- def _extract_kwargs(func: callable, **kwargs) -> dict:
209
- """
210
- Extract the kwargs that are valid for the given function.
211
- """
212
- return {
213
- k: v for k, v in kwargs.items() if k in inspect.signature(func).parameters
214
- }
215
-
216
- def save_pretrained(self, save_directory, **kwargs):
217
- """
218
- Save both the image processor and tokenizer.
219
- """
220
- if self.image_processor is not None:
221
- self.image_processor.save_pretrained(
222
- save_directory,
223
- **self._extract_kwargs(self.image_processor.save_pretrained, **kwargs),
224
- )
225
- if self.tokenizer is not None:
226
- self.tokenizer.save_pretrained(
227
- save_directory,
228
- **self._extract_kwargs(self.tokenizer.save_pretrained, **kwargs),
229
- )
230
-
231
- @classmethod
232
- def from_pretrained(
233
- cls,
234
- pretrained_model_name_or_path,
235
- tokenizer_path=None,
236
- image_processor_path=None,
237
- **kwargs,
238
- ):
239
- """
240
- Load both the image processor and tokenizer from a pretrained model path.
241
- """
242
- tokenizer_path = (
243
- tokenizer_path
244
- if tokenizer_path is not None
245
- else pretrained_model_name_or_path
246
- )
247
- image_processor_path = (
248
- image_processor_path
249
- if image_processor_path is not None
250
- else pretrained_model_name_or_path
251
- )
252
- image_processor = AriaVisionProcessor.from_pretrained(
253
- image_processor_path,
254
- **cls._extract_kwargs(AriaVisionProcessor.from_pretrained, **kwargs),
255
- )
256
- if "use_fast" in kwargs:
257
- logger.warning("use_fast is not supported for AriaProcessor. Ignoring...")
258
- kwargs.pop("use_fast")
259
- try:
260
- tokenizer = AutoTokenizer.from_pretrained(
261
- tokenizer_path,
262
- use_fast=False,
263
- **cls._extract_kwargs(AutoTokenizer.from_pretrained, **kwargs),
264
- )
265
- chat_template = tokenizer.chat_template
266
- except Exception as e:
267
- logger.warning(f"Failed to load tokenizer from {tokenizer_path}: {e}")
268
- tokenizer = None
269
- chat_template = None
270
- return cls(
271
- image_processor=image_processor,
272
- tokenizer=tokenizer,
273
- chat_template=chat_template,
274
- )
275
-
276
- # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
277
- def batch_decode(self, *args, **kwargs):
278
- """
279
- This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
280
- refer to the docstring of this method for more information.
281
- """
282
- if self.tokenizer is None:
283
- raise ValueError(
284
- "Tokenizer is not initialized. Please provide a valid tokenizer."
285
- )
286
- return self.tokenizer.batch_decode(*args, **kwargs)
287
-
288
- # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
289
- def decode(self, *args, **kwargs):
290
- """
291
- This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
292
- the docstring of this method for more information.
293
- """
294
- if self.tokenizer is None:
295
- raise ValueError(
296
- "Tokenizer is not initialized. Please provide a valid tokenizer."
297
- )
298
- return self.tokenizer.decode(*args, **kwargs)
299
-
300
- @property
301
- # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
302
- def model_input_names(self):
303
- tokenizer_input_names = self.tokenizer.model_input_names
304
- image_processor_input_names = self.image_processor.model_input_names
305
- return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
 
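For reference, a minimal self-contained sketch of the two-step image-token expansion performed in the deleted __call__ above. The placeholder string, prompts and crop counts are illustrative stand-ins, not values read from the repository.

import re

IMAGE_TOKEN = "<|img|>"                          # hypothetical placeholder; the real token comes from the tokenizer
prompts = [IMAGE_TOKEN + " describe the image", IMAGE_TOKEN + " and this one"]
crop_iter = iter([3, 1])                         # e.g. first image split into 3 crops, second left whole

# Step 1: repeat the placeholder once per crop of the matching image.
expanded = [
    re.sub(re.escape(IMAGE_TOKEN), lambda _: next(crop_iter) * IMAGE_TOKEN, p)
    for p in prompts
]

# Step 2: every placeholder then becomes a fixed run of image tokens,
# 128 when max_image_size == 490 and 256 when max_image_size == 980.
num_image_tokens = 256
final_prompts = [p.replace(IMAGE_TOKEN, IMAGE_TOKEN * num_image_tokens) for p in expanded]
print(final_prompts[0].count(IMAGE_TOKEN))       # 3 crops * 256 = 768 image tokens in the first prompt

The tokenizer then sees a text stream whose placeholder count matches the number of visual feature slots the projector will emit.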
 
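The _extract_kwargs helper above is a small signature-filtering trick. Here is a stand-alone version with a toy callee (save below is a made-up function, not part of the repo) showing how unsupported keyword arguments are silently dropped.

import inspect

def extract_kwargs(func, **kwargs):
    # Keep only the keyword arguments that `func` actually declares.
    return {k: v for k, v in kwargs.items() if k in inspect.signature(func).parameters}

def save(directory, *, safe_serialization=True):     # toy stand-in for tokenizer.save_pretrained
    return directory, safe_serialization

filtered = extract_kwargs(save, safe_serialization=False, push_to_hub=True)
print(filtered)   # {'safe_serialization': False}; 'push_to_hub' is not in save()'s signature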
 
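Since this commit removes the custom processing code, the composed processor can only be loaded from a revision that still contains these files. A hedged sketch with the standard auto class; the repo id and revision below are placeholders to substitute, not verified values.

from transformers import AutoProcessor

repo_id = "rhymes-ai/Aria"                 # assumed hub location of the model
pre_deletion_revision = "<commit-sha>"     # any revision that still ships processing_aria.py

processor = AutoProcessor.from_pretrained(
    repo_id,
    revision=pre_deletion_revision,
    trust_remote_code=True,                # AriaProcessor is repo-local code, not part of transformers
)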
 
projector.py DELETED
@@ -1,189 +0,0 @@
1
- # Copyright 2024 Rhymes AI. All rights reserved.
2
- #
3
- # Licensed to the Apache Software Foundation (ASF) under one
4
- # or more contributor license agreements. See the NOTICE file
5
- # distributed with this work for additional information
6
- # regarding copyright ownership. The ASF licenses this file
7
- # to you under the Apache License, Version 2.0 (the
8
- # "License"); you may not use this file except in compliance
9
- # with the License. You may obtain a copy of the License at
10
- #
11
- # http://www.apache.org/licenses/LICENSE-2.0
12
- #
13
- # Unless required by applicable law or agreed to in writing,
14
- # software distributed under the License is distributed on an
15
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16
- # KIND, either express or implied. See the License for the
17
- # specific language governing permissions and limitations
18
- # under the License.
19
-
20
- import torch
21
- import torch.nn as nn
22
- from torch.nn.init import trunc_normal_
23
- from transformers.activations import ACT2FN
24
-
25
-
26
- class FFN(nn.Module):
27
- """
28
- Feed-Forward Network module.
29
-
30
- Args:
31
- embed_dim (int): Input embedding dimension.
32
- ff_dim (int): Hidden dimension of the feed-forward network.
33
- output_dim (int): Output dimension.
34
- """
35
-
36
- def __init__(self, embed_dim, ff_dim, output_dim):
37
- super().__init__()
38
- self.linear_in = nn.Linear(embed_dim, ff_dim, bias=False)
39
- self.linear_out = nn.Linear(ff_dim, output_dim, bias=False)
40
- self.act = ACT2FN["gelu_new"]
41
-
42
- def forward(self, hidden_states):
43
- hidden_states = self.act(self.linear_in(hidden_states))
44
- hidden_states = self.linear_out(hidden_states)
45
- return hidden_states
46
-
47
-
48
- class CrossAttention(nn.Module):
49
- """
50
- Cross-Attention module.
51
-
52
- Args:
53
- kv_dim (int): Dimension of key and value.
54
- embed_dim (int): Embedding dimension.
55
- num_heads (int): Number of attention heads.
56
- drop_out_rate (float): Dropout rate. Default is 0.
57
- """
58
-
59
- def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0):
60
- super().__init__()
61
- self.num_heads = num_heads
62
- self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
63
- self.k_proj = nn.Linear(kv_dim, embed_dim, bias=False)
64
- self.v_proj = nn.Linear(kv_dim, embed_dim, bias=False)
65
-
66
- self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
67
- self.linear = nn.Linear(embed_dim, embed_dim)
68
- self.dropout = nn.Dropout(drop_out_rate)
69
-
70
- self.layer_norm = nn.LayerNorm(embed_dim)
71
- self.ln_kv = nn.LayerNorm(kv_dim)
72
-
73
- def forward(self, x, hidden_states, attn_mask=None, add_residual=False):
74
- """
75
- Forward pass of the CrossAttention module.
76
-
77
- Args:
78
- x (torch.Tensor): Input tensor for key and value.
79
- hidden_states (torch.Tensor): Input tensor for query.
80
- attn_mask (torch.Tensor, optional): Attention mask. Default is None.
81
- add_residual (bool): Whether to add residual connection. Default is False.
82
-
83
- Returns:
84
- torch.Tensor: Output tensor after cross-attention.
85
- """
86
- normed_hidden_states = self.layer_norm(hidden_states)
87
- query = self.q_proj(normed_hidden_states).permute(1, 0, 2)
88
-
89
- x = self.ln_kv(x)
90
- key = self.k_proj(x).permute(1, 0, 2)
91
- value = self.v_proj(x).permute(1, 0, 2)
92
-
93
- attn_output, _ = self.multihead_attn(query, key, value, attn_mask=attn_mask)
94
-
95
- attn_output = attn_output.permute(1, 0, 2)
96
-
97
- if add_residual:
98
- attn_output = hidden_states + self.dropout(self.linear(attn_output))
99
- else:
100
- attn_output = self.dropout(self.linear(attn_output))
101
-
102
- return attn_output
103
-
104
-
105
- class AriaProjector(nn.Module):
106
- """
107
- A projection module with one cross attention layer and one FFN layer, which projects ViT's outputs into MoE's inputs.
108
-
109
- Args:
110
- patch_to_query_dict (dict): Maps patch numbers to their corresponding query numbers,
111
- e.g., {1225: 128, 4900: 256}. This allows for different query sizes based on image resolution.
112
- embed_dim (int): Embedding dimension.
113
- num_heads (int): Number of attention heads.
114
- kv_dim (int): Dimension of key and value.
115
- ff_dim (int): Hidden dimension of the feed-forward network.
116
- output_dim (int): Output dimension.
117
- norm_layer (nn.Module): Normalization layer. Default is nn.LayerNorm.
118
-
119
- Outputs:
120
- A tensor with the shape of (batch_size, query_number, output_dim)
121
- """
122
-
123
- def __init__(
124
- self,
125
- patch_to_query_dict,
126
- embed_dim,
127
- num_heads,
128
- kv_dim,
129
- ff_dim,
130
- output_dim,
131
- norm_layer=nn.LayerNorm,
132
- ):
133
- super().__init__()
134
- self.patch_to_query_dict = patch_to_query_dict
135
- self.embed_dim = embed_dim
136
- self.num_heads = num_heads
137
-
138
- self.query = nn.Parameter(
139
- torch.zeros(max(patch_to_query_dict.values()), self.embed_dim)
140
- )
141
-
142
- trunc_normal_(self.query, std=0.02)
143
-
144
- self.cross_attn = CrossAttention(kv_dim, embed_dim, num_heads)
145
-
146
- self.ln_ffn = norm_layer(embed_dim)
147
- self.ffn = FFN(embed_dim, ff_dim, output_dim)
148
-
149
- self.apply(self._init_weights)
150
-
151
- def _init_weights(self, m):
152
- if isinstance(m, nn.Linear):
153
- trunc_normal_(m.weight, std=0.02)
154
- if isinstance(m, nn.Linear) and m.bias is not None:
155
- nn.init.constant_(m.bias, 0)
156
- elif isinstance(m, nn.LayerNorm):
157
- nn.init.constant_(m.bias, 0)
158
- nn.init.constant_(m.weight, 1.0)
159
-
160
- def forward(self, x, attn_mask=None):
161
- """
162
- Forward pass of the Projector module.
163
-
164
- Args:
165
- x (torch.Tensor): Input tensor of shape (batch_size, num_patches, kv_dim).
166
- attn_mask (torch.Tensor, optional): Attention mask. Default is None.
167
-
168
- Returns:
169
- torch.Tensor: Output tensor of shape (batch_size, query_number, output_dim).
170
- """
171
- bs = x.shape[0]
172
- queries = self.query.unsqueeze(0).repeat(bs, 1, 1)
173
-
174
- query_num = self.patch_to_query_dict.get(x.shape[1], None)
175
- assert (
176
- query_num is not None
177
- ), f"Query number for {x.shape[1]} patches is not provided"
178
-
179
- queries = queries[:, :query_num, :]
180
-
181
- if attn_mask is not None:
182
- attn_mask = attn_mask.repeat_interleave(self.num_heads, 0)
183
- attn_mask = attn_mask.unsqueeze(1).expand(-1, queries.size(1), -1)
184
-
185
- attention_out = self.cross_attn(x, queries, attn_mask=attn_mask)
186
-
187
- out = self.ffn(self.ln_ffn(attention_out))
188
-
189
- return out
 
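A toy re-enactment of how the deleted AriaProjector selects its learned queries: the query bank is allocated for the largest entry of patch_to_query_dict and sliced down to the entry matching the incoming patch count. The 1225/4900 mapping is the one used above; the embedding sizes are arbitrary.

import torch

patch_to_query_dict = {1225: 128, 4900: 256}
embed_dim = 8                                        # toy size for illustration
query_bank = torch.zeros(max(patch_to_query_dict.values()), embed_dim)

x = torch.randn(2, 1225, 16)                         # batch of ViT outputs with 1225 patches each
query_num = patch_to_query_dict.get(x.shape[1])
assert query_num is not None, f"Query number for {x.shape[1]} patches is not provided"

queries = query_bank.unsqueeze(0).repeat(x.shape[0], 1, 1)[:, :query_num, :]
print(queries.shape)                                 # torch.Size([2, 128, 8])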
 
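And a shape-only sketch of the attention-mask expansion in AriaProjector.forward: nn.MultiheadAttention expects one mask per head, so the per-image key mask is repeated across heads and broadcast over the query length. Batch, head and length values here are arbitrary.

import torch

num_heads = 4
batch, q_len, kv_len = 2, 128, 1225

# One entry per key/value position; True marks padded patches that must be ignored,
# matching the image attention mask convention of the vision tower.
kv_padding = torch.zeros(batch, kv_len, dtype=torch.bool)

attn_mask = kv_padding.repeat_interleave(num_heads, 0)    # (batch * num_heads, kv_len)
attn_mask = attn_mask.unsqueeze(1).expand(-1, q_len, -1)  # (batch * num_heads, q_len, kv_len)
print(attn_mask.shape)                                    # torch.Size([8, 128, 1225])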
 
vision_encoder.py DELETED
@@ -1,152 +0,0 @@
1
- # Copyright 2024 Rhymes AI. All rights reserved.
2
- #
3
- # Licensed to the Apache Software Foundation (ASF) under one
4
- # or more contributor license agreements. See the NOTICE file
5
- # distributed with this work for additional information
6
- # regarding copyright ownership. The ASF licenses this file
7
- # to you under the Apache License, Version 2.0 (the
8
- # "License"); you may not use this file except in compliance
9
- # with the License. You may obtain a copy of the License at
10
- #
11
- # http://www.apache.org/licenses/LICENSE-2.0
12
- #
13
- # Unless required by applicable law or agreed to in writing,
14
- # software distributed under the License is distributed on an
15
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16
- # KIND, either express or implied. See the License for the
17
- # specific language governing permissions and limitations
18
- # under the License.
19
-
20
- """PyTorch Aria vision transformer."""
21
-
22
- from typing import Optional, Tuple, Union
23
-
24
- import torch
25
- import torch.utils.checkpoint
26
- from transformers import SiglipVisionConfig, SiglipVisionModel
27
- from transformers.modeling_outputs import BaseModelOutputWithPooling
28
- from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer
29
-
30
-
31
- class AriaVisionConfig(SiglipVisionConfig):
32
- """Configuration class for AriaVisionModel."""
33
-
34
- model_type = "aria_vision_model"
35
-
36
- def __init__(
37
- self,
38
- **kwargs,
39
- ):
40
- super().__init__(**kwargs)
41
-
42
-
43
- class IdentityOp(torch.nn.Module):
44
- """
45
- An identity operation that returns the input unchanged.
46
-
47
- This can be used as a placeholder or to maintain architectural consistency
48
- when a specific operation is not needed.
49
- """
50
-
51
- def __init__(self, *args, **kwargs):
52
- super().__init__()
53
-
54
- def forward(self, x, *args, **kwargs):
55
- return x
56
-
57
-
58
- class AriaVisionTransformer(Idefics2VisionTransformer):
59
- """
60
- Aria Vision Transformer model based on Idefics2VisionTransformer.
61
-
62
- This class extends the original Idefics2VisionTransformer by removing the post-layernorm operation.
63
- """
64
-
65
- def __init__(self, config: AriaVisionConfig):
66
- super().__init__(config)
67
- self.post_layernorm = IdentityOp()
68
-
69
-
70
- class AriaVisionModel(SiglipVisionModel):
71
- """
72
- Aria Vision Model extends SiglipVisionModel to support pixel_mask.
73
-
74
- The pixel_mask is a 2D boolean tensor that indicates which pixels in the input
75
- image are actual content and which are padding. It has the same height and width
76
- as the input image, where:
77
- - True (1) values represent pixels from the original image
78
- - False (0) values represent padding pixels
79
-
80
- This mask helps the model focus on the relevant parts of the image during processing.
81
- """
82
-
83
- config_class = AriaVisionConfig
84
- main_input_name = "pixel_values"
85
- _supports_sdpa = False
86
-
87
- def __init__(self, config: AriaVisionConfig):
88
- super().__init__(config)
89
- self.vision_model = AriaVisionTransformer(config)
90
-
91
- # Initialize weights and apply final processing
92
- self.post_init()
93
-
94
- def forward(
95
- self,
96
- pixel_values: torch.Tensor,
97
- pixel_mask: Optional[torch.BoolTensor] = None,
98
- output_attentions: Optional[bool] = None,
99
- output_hidden_states: Optional[bool] = None,
100
- return_dict: Optional[bool] = None,
101
- ) -> Union[Tuple, BaseModelOutputWithPooling]:
102
- """
103
- Forward pass of the AriaVisionModel.
104
-
105
- Args:
106
- pixel_values (torch.Tensor): The pixel values of the input images.
107
- pixel_mask (Optional[torch.BoolTensor]): Mask for the pixel values.
108
- output_attentions (Optional[bool]): Whether to output attentions.
109
- output_hidden_states (Optional[bool]): Whether to output hidden states.
110
- return_dict (Optional[bool]): Whether to return a ModelOutput object.
111
-
112
- Returns:
113
- Union[Tuple, BaseModelOutputWithPooling]: The model's output.
114
- """
115
- return_dict = (
116
- return_dict if return_dict is not None else self.config.use_return_dict
117
- )
118
- patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
119
-
120
- vit_oup = self.vision_model(
121
- pixel_values=pixel_values,
122
- patch_attention_mask=patch_attention_mask,
123
- output_attentions=output_attentions,
124
- output_hidden_states=output_hidden_states,
125
- return_dict=return_dict,
126
- )
127
-
128
- image_atts = self._create_image_attention_mask(patch_attention_mask)
129
-
130
- return vit_oup, image_atts
131
-
132
- def _create_patch_attention_mask(self, pixel_mask):
133
- if pixel_mask is None:
134
- return None
135
-
136
- patches_subgrid = pixel_mask.unfold(
137
- dimension=1,
138
- size=self.vision_model.config.patch_size,
139
- step=self.vision_model.config.patch_size,
140
- ).unfold(
141
- dimension=2,
142
- size=self.vision_model.config.patch_size,
143
- step=self.vision_model.config.patch_size,
144
- )
145
- return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
146
-
147
- def _create_image_attention_mask(self, patch_attention_mask):
148
- if patch_attention_mask is None:
149
- return None
150
-
151
- flattened_mask = patch_attention_mask.flatten(1)
152
- return torch.logical_not(flattened_mask)
 
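For intuition, a self-contained sketch of the two mask helpers above: unfold tiles the pixel mask into patch_size x patch_size blocks, a patch counts as visible if any of its pixels is real, and the flattened, negated result is what is returned as the image attention mask. The 980/560 sizes are an example; patch_size=14 is consistent with the 1225 and 4900 patch counts used elsewhere in this commit.

import torch

patch_size = 14
pixel_mask = torch.zeros(1, 980, 980, dtype=torch.bool)
pixel_mask[:, :560, :980] = True                   # e.g. a 980x560 image padded to 980x980

patches = pixel_mask.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
patch_attention_mask = patches.sum(dim=(-1, -2)) > 0             # (1, 70, 70), True = patch has real pixels
image_atts = torch.logical_not(patch_attention_mask.flatten(1))  # (1, 4900), True = fully padded patch
print(patch_attention_mask.shape, image_atts.shape)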
 
vision_processor.py DELETED
@@ -1,321 +0,0 @@
1
- # Copyright 2024 Rhymes AI. All rights reserved.
2
- #
3
- # Licensed to the Apache Software Foundation (ASF) under one
4
- # or more contributor license agreements. See the NOTICE file
5
- # distributed with this work for additional information
6
- # regarding copyright ownership. The ASF licenses this file
7
- # to you under the Apache License, Version 2.0 (the
8
- # "License"); you may not use this file except in compliance
9
- # with the License. You may obtain a copy of the License at
10
- #
11
- # http://www.apache.org/licenses/LICENSE-2.0
12
- #
13
- # Unless required by applicable law or agreed to in writing,
14
- # software distributed under the License is distributed on an
15
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16
- # KIND, either express or implied. See the License for the
17
- # specific language governing permissions and limitations
18
- # under the License.
19
-
20
- from typing import List, Optional, Union
21
-
22
- import numpy as np
23
- import torch
24
- from PIL import Image, ImageOps
25
- from torchvision import transforms
26
- from transformers import BaseImageProcessor, BatchFeature, TensorType
27
-
28
-
29
- def _select_best_resolution(
30
- img_width: int, img_height: int, target_ratios: List[List[int]], patch_size: int
31
- ):
32
- """
33
- Selects the best resolution from a list of possible resolutions based on the original size.
34
-
35
- Args:
36
- img_width: the original widths of images.
37
- img_height: the original heights of images.
38
- target_ratios (2d numpy array): dimension size (M,2)
39
- patch_size (int): image patch size
40
-
41
- Returns:
42
- tuple: The best fit resolution in the format (width, height).
43
- """
44
-
45
- aspect_ratio = img_width / img_height
46
- best_ratio_diff = float("inf")
47
- best_ratio_w, best_ratio_h = 1, 1
48
- area = np.int32(img_width) * np.int32(img_height)
49
- for ratio in target_ratios:
50
- target_aspect_ratio = ratio[0] / ratio[1]
51
- ratio_diff = abs(aspect_ratio - target_aspect_ratio)
52
- if ratio_diff < best_ratio_diff:
53
- best_ratio_diff = ratio_diff
54
- best_ratio_w, best_ratio_h = ratio[0], ratio[1]
55
- elif (
56
- ratio_diff == best_ratio_diff
57
- and area > 0.5 * patch_size * patch_size * ratio[0] * ratio[1]
58
- ):
59
- best_ratio_w, best_ratio_h = ratio[0], ratio[1]
60
-
61
- return best_ratio_w, best_ratio_h
62
-
63
-
64
- def _split_image(
65
- image: Image.Image,
66
- split_image: bool,
67
- split_ratio: List[List[int]],
68
- patch_size: int,
69
- ) -> List[Image.Image]:
70
- """
71
- Split image into multiple patches
72
-
73
- Args:
74
- image (PIL.Image): Input image.
75
- split_image (bool): Whether to split the image into patches.
76
- split_ratio (2d numpy array): dimension size (M,2)
77
- patch_size (int): image patch size
78
-
79
- Returns:
80
- List[PIL.Image]: List of split images.
81
- """
82
- if split_image:
83
- ratio_width, ratio_height = _select_best_resolution(
84
- image.width, image.height, split_ratio, patch_size
85
- )
86
- resize_width = patch_size * ratio_width
87
- resize_height = patch_size * ratio_height
88
- blocks = ratio_width * ratio_height
89
- resized_img = image.resize((resize_width, resize_height))
90
- processed_images = []
91
- for i in range(blocks):
92
- box = (
93
- (i % (resize_width // patch_size)) * patch_size,
94
- (i // (resize_width // patch_size)) * patch_size,
95
- ((i % (resize_width // patch_size)) + 1) * patch_size,
96
- ((i // (resize_width // patch_size)) + 1) * patch_size,
97
- )
98
- # split the image
99
- split_img = resized_img.crop(box)
100
- processed_images.append(split_img)
101
- assert len(processed_images) == blocks
102
- if len(processed_images) != 1:
103
- processed_images.insert(0, image)
104
- return processed_images
105
- else:
106
- return [image]
107
-
108
-
109
- def keep_ratio_resize_and_pixel_mask(
110
- img: Image.Image, max_size, min_size=336, padding_value=0
111
- ):
112
- """
113
- Resize an image while maintaining aspect ratio and create a pixel mask.
114
-
115
- Args:
116
- img (PIL.Image): Input image.
117
- max_size (int): Maximum size for the larger dimension of the image.
118
- min_size (int, optional): Minimum size for the smaller dimension. Defaults to 336.
119
- padding_value (int, optional): Value used for padding. Defaults to 0.
120
-
121
- Returns:
122
- tuple: A tuple containing:
123
- - PIL.Image: Resized and padded image.
124
- - torch.Tensor: Boolean pixel mask. This mask is a 2D tensor of shape (max_size, max_size) where:
125
- - True (1) values indicate pixels that belong to the original resized image.
126
- - False (0) values indicate pixels that are part of the padding.
127
- The mask helps distinguish between actual image content and padded areas in subsequent processing steps.
128
- """
129
- img = img.convert("RGB")
130
- # rescale the given image, keep the aspect ratio
131
- scale = max_size / max(img.size)
132
-
133
- w, h = img.size
134
- if w >= h:
135
- new_size = (max_size, max(int(h * scale), min_size)) # w, h
136
- else:
137
- new_size = (max(int(w * scale), min_size), max_size) # w, h
138
-
139
- img_resized = img.resize(new_size, resample=Image.Resampling.BICUBIC)
140
-
141
- # padding the right/bottom
142
- padding_right, padding_bottom = max_size - new_size[0], max_size - new_size[1]
143
- img_padded = ImageOps.expand(
144
- img_resized, (0, 0, padding_right, padding_bottom), fill=padding_value
145
- )
146
-
147
- # Create a pixel mask
148
- pixel_mask = torch.zeros(max_size, max_size)
149
- pixel_mask[: new_size[1], : new_size[0]] = 1
150
- pixel_mask = pixel_mask.bool()
151
- return img_padded, pixel_mask
152
-
153
-
154
- class AriaVisionProcessor(BaseImageProcessor):
155
- """
156
- A vision processor for the Aria model that handles image preprocessing.
157
- """
158
-
159
- def __init__(
160
- self,
161
- max_image_size=980,
162
- min_image_size=336,
163
- image_mean=[0.5, 0.5, 0.5],
164
- image_std=[0.5, 0.5, 0.5],
165
- **kwargs,
166
- ):
167
- """
168
- Initialize the AriaVisionProcessor.
169
-
170
- Args:
171
- max_image_size (int, optional): Maximum image size. Defaults to 980.
172
- min_image_size (int, optional): Minimum image size. Defaults to 336.
173
- image_mean (list, optional): Mean values for normalization. Defaults to [0.5, 0.5, 0.5].
174
- image_std (list, optional): Standard deviation values for normalization. Defaults to [0.5, 0.5, 0.5].
175
- """
176
- super().__init__(**kwargs)
177
-
178
- self.max_image_size = max_image_size
179
- self.min_image_size = min_image_size
180
- self.image_mean = image_mean
181
- self.image_std = image_std
182
- self.auto_map = {
183
- "AutoProcessor": "processing_aria.AriaProcessor",
184
- "AutoImageProcessor": "vision_processor.AriaVisionProcessor",
185
- }
186
-
187
- # we make the transform a property so that it is lazily initialized,
188
- # this could avoid the error "TypeError: Object of type Normalize is not JSON serializable"
189
- # when we used save_pretrained or from_pretrained.
190
- self._transform = None
191
- self._set_processor_class("AriaProcessor")
192
-
193
- @property
194
- def transform(self):
195
- if self._transform is None:
196
- # Recreate the transform when accessed
197
- self._transform = transforms.Compose(
198
- [
199
- transforms.ToTensor(),
200
- transforms.Normalize(self.image_mean, self.image_std),
201
- ]
202
- )
203
- return self._transform
204
-
205
- def __call__(
206
- self,
207
- images: Union[Image.Image, List[Image.Image]],
208
- max_image_size: Optional[int] = 980,
209
- min_image_size: Optional[int] = 336,
210
- return_tensors: Optional[Union[str, TensorType]] = "pt",
211
- split_image: Optional[bool] = False,
212
- split_ratio: Optional[List[List[int]]] = [
213
- [1, 2],
214
- [1, 3],
215
- [1, 4],
216
- [1, 5],
217
- [1, 6],
218
- [1, 7],
219
- [1, 8],
220
- [2, 4],
221
- [2, 3],
222
- [2, 2],
223
- [2, 1],
224
- [3, 1],
225
- [3, 2],
226
- [4, 1],
227
- [4, 2],
228
- [5, 1],
229
- [6, 1],
230
- [7, 1],
231
- [8, 1],
232
- ],
233
- ):
234
- """
235
- Process a list of images.
236
-
237
- Args:
238
- images (list): List of PIL.Image objects.
239
- max_image_size (int, optional): Override the processor's default max image size (must be 490 or 980). Defaults to 980.
240
- return_tensors (str or TensorType, optional): The type of tensor to return. Defaults to "pt".
241
- split_image (bool, optional): Whether to split the image. Defaults to False.
242
- split_ratio (list, optional): The ratio for splitting the image. Defaults to a list of common split ratios.
243
- Returns:
244
- BatchFeature: A BatchFeature object containing:
245
- - 'pixel_values': Tensor of processed image pixel values.
246
- - 'pixel_mask': Boolean pixel mask. This mask is a 2D tensor of shape (max_size, max_size) where:
247
- - True (1) values indicate pixels that belong to the original resized image.
248
- - False (0) values indicate pixels that are part of the padding.
249
- The mask helps distinguish between actual image content and padded areas in subsequent processing steps.
250
- - 'num_crops': Tensor of the number of crops for each image.
251
- """
252
- max_size = self.max_image_size if max_image_size is None else max_image_size
253
- min_size = self.min_image_size if min_image_size is None else min_image_size
254
-
255
- if max_size not in [490, 980]:
256
- raise ValueError("max_image_size must be either 490 or 980")
257
-
258
- if isinstance(images, Image.Image):
259
- images = [images]
260
-
261
- pixel_values = []
262
- pixel_masks = []
263
- num_crops = []
264
-
265
- for image in images:
266
- crop_images = _split_image(image, split_image, split_ratio, max_size)
267
- num_crops.append(torch.tensor(len(crop_images)))
268
- for crop_image in crop_images:
269
- img_padded, pixel_mask = keep_ratio_resize_and_pixel_mask(
270
- crop_image, max_size, min_size
271
- )
272
- img_padded = self.transform(img_padded)
273
- pixel_values.append(img_padded)
274
- pixel_masks.append(pixel_mask)
275
-
276
- return BatchFeature(
277
- data={
278
- "pixel_values": torch.stack(pixel_values),
279
- "pixel_mask": torch.stack(pixel_masks),
280
- "num_crops": torch.stack(num_crops),
281
- },
282
- tensor_type=return_tensors,
283
- )
284
-
285
- def preprocess(
286
- self,
287
- images,
288
- max_image_size=None,
289
- min_image_size=None,
290
- return_tensors: Optional[Union[str, TensorType]] = None,
291
- split_image: Optional[bool] = False,
292
- split_ratio: Optional[List[List[int]]] = [
293
- [1, 2],
294
- [1, 3],
295
- [1, 4],
296
- [1, 5],
297
- [1, 6],
298
- [1, 7],
299
- [1, 8],
300
- [2, 4],
301
- [2, 3],
302
- [2, 2],
303
- [2, 1],
304
- [3, 1],
305
- [3, 2],
306
- [4, 1],
307
- [4, 2],
308
- [5, 1],
309
- [6, 1],
310
- [7, 1],
311
- [8, 1],
312
- ],
313
- ):
314
- return self.__call__(
315
- images,
316
- max_image_size=max_image_size,
317
- min_image_size=min_image_size,
318
- return_tensors=return_tensors,
319
- split_image=split_image,
320
- split_ratio=split_ratio,
321
- )
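
Finally, a stand-alone sketch of keep_ratio_resize_and_pixel_mask as defined above, run on a synthetic 1600x900 image: the long side is scaled to max_size, the right/bottom is padded, and a boolean mask records which pixels are real. The blank test image and exact sizes are illustrative only.

import torch
from PIL import Image, ImageOps

max_size, min_size = 980, 336
img = Image.new("RGB", (1600, 900))                # blank stand-in for a real photo

scale = max_size / max(img.size)
w, h = img.size
if w >= h:
    new_size = (max_size, max(int(h * scale), min_size))   # (width, height)
else:
    new_size = (max(int(w * scale), min_size), max_size)

resized = img.resize(new_size, resample=Image.Resampling.BICUBIC)   # Pillow >= 9.1
padded = ImageOps.expand(
    resized, (0, 0, max_size - new_size[0], max_size - new_size[1]), fill=0
)

pixel_mask = torch.zeros(max_size, max_size, dtype=torch.bool)
pixel_mask[: new_size[1], : new_size[0]] = True    # True = real content, False = right/bottom padding
print(padded.size, int(pixel_mask.sum()))          # (980, 980) and 980 * 551 real pixels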