radna committed
Commit a836c79 · verified · Parent: 2b894e3

Update modeling_intern_vit.py

Files changed (1):
  1. modeling_intern_vit.py +365 -363
modeling_intern_vit.py CHANGED
@@ -1,363 +1,365 @@
- # --------------------------------------------------------
- # InternVL
- # Copyright (c) 2023 OpenGVLab
- # Licensed under The MIT License [see LICENSE for details]
- # --------------------------------------------------------
- from typing import Optional, Tuple, Union
-
- import torch
- import torch.nn.functional as F
- import torch.utils.checkpoint
- from einops import rearrange
- from timm.models.layers import DropPath
- from torch import nn
- from transformers.activations import ACT2FN
- from transformers.modeling_outputs import (BaseModelOutput,
-                                            BaseModelOutputWithPooling)
- from transformers.modeling_utils import PreTrainedModel
- from transformers.utils import logging
-
- from .configuration_intern_vit import InternVisionConfig
-
- try:
-     from .flash_attention import FlashAttention
-     has_flash_attn = True
- except:
-     print('FlashAttention is not installed.')
-     has_flash_attn = False
-
-
- logger = logging.get_logger(__name__)
-
-
- class InternRMSNorm(nn.Module):
-     def __init__(self, hidden_size, eps=1e-6):
-         super().__init__()
-         self.weight = nn.Parameter(torch.ones(hidden_size))
-         self.variance_epsilon = eps
-
-     def forward(self, hidden_states):
-         input_dtype = hidden_states.dtype
-         hidden_states = hidden_states.to(torch.float32)
-         variance = hidden_states.pow(2).mean(-1, keepdim=True)
-         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-         return self.weight * hidden_states.to(input_dtype)
-
-
- try:
-     from apex.normalization import FusedRMSNorm
-
-     InternRMSNorm = FusedRMSNorm  # noqa
-
-     logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
- except ImportError:
-     # using the normal InternRMSNorm
-     pass
- except Exception:
-     logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')
-     pass
-
-
- NORM2FN = {
-     'rms_norm': InternRMSNorm,
-     'layer_norm': nn.LayerNorm,
- }
-
-
- class InternVisionEmbeddings(nn.Module):
-     def __init__(self, config: InternVisionConfig):
-         super().__init__()
-         self.config = config
-         self.embed_dim = config.hidden_size
-         self.image_size = config.image_size
-         self.patch_size = config.patch_size
-
-         self.class_embedding = nn.Parameter(
-             torch.randn(1, 1, self.embed_dim),
-         )
-
-         self.patch_embedding = nn.Conv2d(
-             in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
-         )
-
-         self.num_patches = (self.image_size // self.patch_size) ** 2
-         self.num_positions = self.num_patches + 1
-
-         self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
-
-     def _get_pos_embed(self, pos_embed, H, W):
-         target_dtype = pos_embed.dtype
-         pos_embed = pos_embed.float().reshape(
-             1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
-         pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False).\
-             reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype)
-         return pos_embed
-
-     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
-         target_dtype = self.patch_embedding.weight.dtype
-         patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, channel, width, height]
-         batch_size, _, height, width = patch_embeds.shape
-         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
-         class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
-         embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
-         position_embedding = torch.cat([
-             self.position_embedding[:, :1, :],
-             self._get_pos_embed(self.position_embedding[:, 1:, :], height, width)
-         ], dim=1)
-         embeddings = embeddings + position_embedding.to(target_dtype)
-         return embeddings
-
-
- class InternAttention(nn.Module):
-     """Multi-headed attention from 'Attention Is All You Need' paper"""
-
-     def __init__(self, config: InternVisionConfig):
-         super().__init__()
-         self.config = config
-         self.embed_dim = config.hidden_size
-         self.num_heads = config.num_attention_heads
-         self.use_flash_attn = config.use_flash_attn and has_flash_attn
-         if config.use_flash_attn and not has_flash_attn:
-             print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
-         self.head_dim = self.embed_dim // self.num_heads
-         if self.head_dim * self.num_heads != self.embed_dim:
-             raise ValueError(
-                 f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
-                 f' {self.num_heads}).'
-             )
-
-         self.scale = self.head_dim ** -0.5
-         self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
-         self.attn_drop = nn.Dropout(config.attention_dropout)
-         self.proj_drop = nn.Dropout(config.dropout)
-
-         self.qk_normalization = config.qk_normalization
-
-         if self.qk_normalization:
-             self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
-             self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
-
-         if self.use_flash_attn:
-             self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
-         self.proj = nn.Linear(self.embed_dim, self.embed_dim)
-
-     def _naive_attn(self, x):
-         B, N, C = x.shape
-         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
-
-         if self.qk_normalization:
-             B_, H_, N_, D_ = q.shape
-             q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
-             k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
-
-         attn = ((q * self.scale) @ k.transpose(-2, -1))
-         attn = attn.softmax(dim=-1)
-         attn = self.attn_drop(attn)
-
-         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
-         x = self.proj(x)
-         x = self.proj_drop(x)
-         return x
-
-     def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
-         qkv = self.qkv(x)
-         qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
-
-         if self.qk_normalization:
-             q, k, v = qkv.unbind(2)
-             q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
-             k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
-             qkv = torch.stack([q, k, v], dim=2)
-
-         context, _ = self.inner_attn(
-             qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
-         )
-         outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
-         outs = self.proj_drop(outs)
-         return outs
-
-     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-         x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
-         return x
-
-
- class InternMLP(nn.Module):
-     def __init__(self, config: InternVisionConfig):
-         super().__init__()
-         self.config = config
-         self.act = ACT2FN[config.hidden_act]
-         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
-         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
-
-     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-         hidden_states = self.fc1(hidden_states)
-         hidden_states = self.act(hidden_states)
-         hidden_states = self.fc2(hidden_states)
-         return hidden_states
-
-
- class InternVisionEncoderLayer(nn.Module):
-     def __init__(self, config: InternVisionConfig, drop_path_rate: float):
-         super().__init__()
-         self.embed_dim = config.hidden_size
-         self.intermediate_size = config.intermediate_size
-         self.norm_type = config.norm_type
-
-         self.attn = InternAttention(config)
-         self.mlp = InternMLP(config)
-         self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
-         self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
-
-         self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
-         self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
-         self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
-         self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
-
-     def forward(
-             self,
-             hidden_states: torch.Tensor,
-     ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
-         """
-         Args:
-             hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
-         """
-         hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)
-
-         hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)
-
-         return hidden_states
-
-
- class InternVisionEncoder(nn.Module):
-     """
-     Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
-     [`InternEncoderLayer`].
-
-     Args:
-         config (`InternConfig`):
-             The corresponding vision configuration for the `InternEncoder`.
-     """
-
-     def __init__(self, config: InternVisionConfig):
-         super().__init__()
-         self.config = config
-         # stochastic depth decay rule
-         dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
-         self.layers = nn.ModuleList([
-             InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
-         self.gradient_checkpointing = True
-
-     def forward(
-             self,
-             inputs_embeds,
-             output_hidden_states: Optional[bool] = None,
-             return_dict: Optional[bool] = None,
-     ) -> Union[Tuple, BaseModelOutput]:
-         r"""
-         Args:
-             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
-                 Embedded representation of the inputs. Should be float, not int tokens.
-             output_hidden_states (`bool`, *optional*):
-                 Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
-                 for more detail.
-             return_dict (`bool`, *optional*):
-                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-         """
-         output_hidden_states = (
-             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-         )
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         encoder_states = () if output_hidden_states else None
-         hidden_states = inputs_embeds
-
-         for idx, encoder_layer in enumerate(self.layers):
-             if output_hidden_states:
-                 encoder_states = encoder_states + (hidden_states,)
-             if self.gradient_checkpointing and self.training:
-                 layer_outputs = torch.utils.checkpoint.checkpoint(
-                     encoder_layer,
-                     hidden_states)
-             else:
-                 layer_outputs = encoder_layer(
-                     hidden_states,
-                 )
-             hidden_states = layer_outputs
-
-         if output_hidden_states:
-             encoder_states = encoder_states + (hidden_states,)
-
-         if not return_dict:
-             return tuple(v for v in [hidden_states, encoder_states] if v is not None)
-         return BaseModelOutput(
-             last_hidden_state=hidden_states, hidden_states=encoder_states
-         )
-
-
- class InternVisionModel(PreTrainedModel):
-     main_input_name = 'pixel_values'
-     config_class = InternVisionConfig
-     _no_split_modules = ['InternVisionEncoderLayer']
-
-     def __init__(self, config: InternVisionConfig):
-         super().__init__(config)
-         self.config = config
-
-         self.embeddings = InternVisionEmbeddings(config)
-         self.encoder = InternVisionEncoder(config)
-
-     def resize_pos_embeddings(self, old_size, new_size, patch_size):
-         pos_emb = self.embeddings.position_embedding
-         _, num_positions, embed_dim = pos_emb.shape
-         cls_emb = pos_emb[:, :1, :]
-         pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
-         pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
-         pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
-         pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
-         self.embeddings.position_embedding = nn.Parameter(pos_emb)
-         self.embeddings.image_size = new_size
-         logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))
-
-     def get_input_embeddings(self):
-         return self.embeddings
-
-     def forward(
-             self,
-             pixel_values: Optional[torch.FloatTensor] = None,
-             output_hidden_states: Optional[bool] = None,
-             return_dict: Optional[bool] = None,
-             pixel_embeds: Optional[torch.FloatTensor] = None,
-     ) -> Union[Tuple, BaseModelOutputWithPooling]:
-         output_hidden_states = (
-             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-         )
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         if pixel_values is None and pixel_embeds is None:
-             raise ValueError('You have to specify pixel_values or pixel_embeds')
-
-         if pixel_embeds is not None:
-             hidden_states = pixel_embeds
-         else:
-             if len(pixel_values.shape) == 4:
-                 hidden_states = self.embeddings(pixel_values)
-             else:
-                 raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
-         encoder_outputs = self.encoder(
-             inputs_embeds=hidden_states,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
-         )
-         last_hidden_state = encoder_outputs.last_hidden_state
-         pooled_output = last_hidden_state[:, 0, :]
-
-         if not return_dict:
-             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
-
-         return BaseModelOutputWithPooling(
-             last_hidden_state=last_hidden_state,
-             pooler_output=pooled_output,
-             hidden_states=encoder_outputs.hidden_states,
-             attentions=encoder_outputs.attentions,
-         )
 
 
 
+ # --------------------------------------------------------
+ # InternVL
+ # Copyright (c) 2023 OpenGVLab
+ # Licensed under The MIT License [see LICENSE for details]
+ # --------------------------------------------------------
+ from typing import Optional, Tuple, Union
+
+ import torch
+ import torch.nn.functional as F
+ import torch.utils.checkpoint
+ from einops import rearrange
+ from timm.models.layers import DropPath
+ from torch import nn
+ from transformers.activations import ACT2FN
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.utils import logging
+
+ from .configuration_intern_vit import InternVisionConfig
+
+
+ try:
+     from triton_flash_atn import _attention
+
+     from triton_bert_pading import pad_input, unpad_input
+
+     has_flash_attn = True
+ except:
+     print("FlashAttention is not installed.")
+     has_flash_attn = False
+
+ logger = logging.get_logger(__name__)
+
+
+ class InternRMSNorm(nn.Module):
+     def __init__(self, hidden_size, eps=1e-6):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.variance_epsilon = eps
+
+     def forward(self, hidden_states):
+         input_dtype = hidden_states.dtype
+         hidden_states = hidden_states.to(torch.float32)
+         variance = hidden_states.pow(2).mean(-1, keepdim=True)
+         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+         return self.weight * hidden_states.to(input_dtype)
+
+
+ try:
+     from apex.normalization import FusedRMSNorm
+
+     InternRMSNorm = FusedRMSNorm  # noqa
+
+     logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
+ except ImportError:
+     # using the normal InternRMSNorm
+     pass
+ except Exception:
+     logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')
+     pass
+
+
+ NORM2FN = {
+     'rms_norm': InternRMSNorm,
+     'layer_norm': nn.LayerNorm,
+ }
+
+
+ class InternVisionEmbeddings(nn.Module):
+     def __init__(self, config: InternVisionConfig):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.hidden_size
+         self.image_size = config.image_size
+         self.patch_size = config.patch_size
+
+         self.class_embedding = nn.Parameter(
+             torch.randn(1, 1, self.embed_dim),
+         )
+
+         self.patch_embedding = nn.Conv2d(
+             in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
+         )
+
+         self.num_patches = (self.image_size // self.patch_size) ** 2
+         self.num_positions = self.num_patches + 1
+
+         self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
+
+     def _get_pos_embed(self, pos_embed, H, W):
+         target_dtype = pos_embed.dtype
+         pos_embed = pos_embed.float().reshape(
+             1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
+         pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False).\
+             reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype)
+         return pos_embed
+
+     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+         target_dtype = self.patch_embedding.weight.dtype
+         patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, channel, width, height]
+         batch_size, _, height, width = patch_embeds.shape
+         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+         class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
+         embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+         position_embedding = torch.cat([
+             self.position_embedding[:, :1, :],
+             self._get_pos_embed(self.position_embedding[:, 1:, :], height, width)
+         ], dim=1)
+         embeddings = embeddings + position_embedding.to(target_dtype)
+         return embeddings
+
+
+ class InternAttention(nn.Module):
+     """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+     def __init__(self, config: InternVisionConfig):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.use_flash_attn = config.use_flash_attn and has_flash_attn
+         if config.use_flash_attn and not has_flash_attn:
+             print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
+         self.head_dim = self.embed_dim // self.num_heads
+         if self.head_dim * self.num_heads != self.embed_dim:
+             raise ValueError(
+                 f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
+                 f' {self.num_heads}).'
+             )
+
+         self.scale = self.head_dim ** -0.5
+         self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
+         self.attn_drop = nn.Dropout(config.attention_dropout)
+         self.proj_drop = nn.Dropout(config.dropout)
+
+         self.qk_normalization = config.qk_normalization
+
+         if self.qk_normalization:
+             self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+             self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+         if self.use_flash_attn:
+             self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
+         self.proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+     def _naive_attn(self, x):
+         B, N, C = x.shape
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
+
+         if self.qk_normalization:
+             B_, H_, N_, D_ = q.shape
+             q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+             k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+
+         attn = ((q * self.scale) @ k.transpose(-2, -1))
+         attn = attn.softmax(dim=-1)
+         attn = self.attn_drop(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+     def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
+         qkv = self.qkv(x)
+         qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
+
+         if self.qk_normalization:
+             q, k, v = qkv.unbind(2)
+             q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
+             k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
+             qkv = torch.stack([q, k, v], dim=2)
+
+         context, _ = self.inner_attn(
+             qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
+         )
+         outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
+         outs = self.proj_drop(outs)
+         return outs
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
+         return x
+
+
+ class InternMLP(nn.Module):
+     def __init__(self, config: InternVisionConfig):
+         super().__init__()
+         self.config = config
+         self.act = ACT2FN[config.hidden_act]
+         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.fc1(hidden_states)
+         hidden_states = self.act(hidden_states)
+         hidden_states = self.fc2(hidden_states)
+         return hidden_states
+
+
+ class InternVisionEncoderLayer(nn.Module):
+     def __init__(self, config: InternVisionConfig, drop_path_rate: float):
+         super().__init__()
+         self.embed_dim = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+         self.norm_type = config.norm_type
+
+         self.attn = InternAttention(config)
+         self.mlp = InternMLP(config)
+         self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
+         self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
+
+         self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
+         self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
+         self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+         self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+
+     def forward(
+             self,
+             hidden_states: torch.Tensor,
+     ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
+         """
+         Args:
+             hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
+         """
+         hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)
+
+         hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)
+
+         return hidden_states
+
+
+ class InternVisionEncoder(nn.Module):
+     """
+     Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+     [`InternEncoderLayer`].
+
+     Args:
+         config (`InternConfig`):
+             The corresponding vision configuration for the `InternEncoder`.
+     """
+
+     def __init__(self, config: InternVisionConfig):
+         super().__init__()
+         self.config = config
+         # stochastic depth decay rule
+         dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
+         self.layers = nn.ModuleList([
+             InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
+         self.gradient_checkpointing = True
+
+     def forward(
+             self,
+             inputs_embeds,
+             output_hidden_states: Optional[bool] = None,
+             return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutput]:
+         r"""
+         Args:
+             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                 Embedded representation of the inputs. Should be float, not int tokens.
+             output_hidden_states (`bool`, *optional*):
+                 Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                 for more detail.
+             return_dict (`bool`, *optional*):
+                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+         """
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         encoder_states = () if output_hidden_states else None
+         hidden_states = inputs_embeds
+
+         for idx, encoder_layer in enumerate(self.layers):
+             if output_hidden_states:
+                 encoder_states = encoder_states + (hidden_states,)
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = torch.utils.checkpoint.checkpoint(
+                     encoder_layer,
+                     hidden_states)
+             else:
+                 layer_outputs = encoder_layer(
+                     hidden_states,
+                 )
+             hidden_states = layer_outputs
+
+         if output_hidden_states:
+             encoder_states = encoder_states + (hidden_states,)
+
+         if not return_dict:
+             return tuple(v for v in [hidden_states, encoder_states] if v is not None)
+         return BaseModelOutput(
+             last_hidden_state=hidden_states, hidden_states=encoder_states
+         )
+
+
+ class InternVisionModel(PreTrainedModel):
+     main_input_name = 'pixel_values'
+     config_class = InternVisionConfig
+     _no_split_modules = ['InternVisionEncoderLayer']
+
+     def __init__(self, config: InternVisionConfig):
+         super().__init__(config)
+         self.config = config
+
+         self.embeddings = InternVisionEmbeddings(config)
+         self.encoder = InternVisionEncoder(config)
+
+     def resize_pos_embeddings(self, old_size, new_size, patch_size):
+         pos_emb = self.embeddings.position_embedding
+         _, num_positions, embed_dim = pos_emb.shape
+         cls_emb = pos_emb[:, :1, :]
+         pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
+         pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
+         pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
+         pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
+         self.embeddings.position_embedding = nn.Parameter(pos_emb)
+         self.embeddings.image_size = new_size
+         logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))
+
+     def get_input_embeddings(self):
+         return self.embeddings
+
+     def forward(
+             self,
+             pixel_values: Optional[torch.FloatTensor] = None,
+             output_hidden_states: Optional[bool] = None,
+             return_dict: Optional[bool] = None,
+             pixel_embeds: Optional[torch.FloatTensor] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPooling]:
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if pixel_values is None and pixel_embeds is None:
+             raise ValueError('You have to specify pixel_values or pixel_embeds')
+
+         if pixel_embeds is not None:
+             hidden_states = pixel_embeds
+         else:
+             if len(pixel_values.shape) == 4:
+                 hidden_states = self.embeddings(pixel_values)
+             else:
+                 raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
+         encoder_outputs = self.encoder(
+             inputs_embeds=hidden_states,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         last_hidden_state = encoder_outputs.last_hidden_state
+         pooled_output = last_hidden_state[:, 0, :]
+
+         if not return_dict:
+             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+         return BaseModelOutputWithPooling(
+             last_hidden_state=last_hidden_state,
+             pooler_output=pooled_output,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+         )
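
For quick verification of the updated file, a short smoke test can push a dummy image through the vision tower. This is only a sketch, not part of the commit: the repository id below is a placeholder, it assumes the repo's config.json maps AutoConfig/AutoModel to the classes defined in this file (as InternViT releases normally do), and it sets use_flash_attn=False so the _naive_attn path is exercised even when the Triton kernels are not installed.

import torch
from transformers import AutoConfig, AutoModel

repo_id = "your-namespace/your-internvit-repo"  # placeholder, substitute the real repo id

# Load the custom config and model classes shipped with the repo (modeling_intern_vit.py).
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
config.use_flash_attn = False  # stay on the _naive_attn path for this check

model = AutoModel.from_pretrained(
    repo_id,
    config=config,
    torch_dtype=torch.float32,
    trust_remote_code=True,
).eval()

# InternVisionModel.forward expects a 4-D pixel_values tensor: (batch, 3, H, W).
pixel_values = torch.randn(1, 3, config.image_size, config.image_size)
with torch.no_grad():
    out = model(pixel_values=pixel_values, return_dict=True)

# last_hidden_state: (1, 1 + (image_size // patch_size) ** 2, hidden_size)
# pooler_output:     (1, hidden_size), i.e. the CLS token embedding
print(out.last_hidden_state.shape, out.pooler_output.shape)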