Files changed (3) hide show
  1. README.md +0 -7
  2. modeling_minicpmv.py +29 -42
  3. resampler.py +7 -662
README.md CHANGED
@@ -198,11 +198,4 @@ If you find our work helpful, please consider citing the following papers
198
  journal={arXiv preprint arXiv:2403.11703},
199
  year={2024}
200
  }
201
- @article{yao2024minicpmvgpt4vlevelmllm,
202
- title={MiniCPM-V: A GPT-4V Level MLLM on Your Phone},
203
- author={Yao, Yuan and Yu, Tianyu and Zhang, Ao and Wang, Chongyi and Cui, Junbo and Zhu, Hongji and Cai, Tianchi and Li, Haoyu and Zhao, Weilin and He, Zhihui and Chen, Qianyu and Zhou, Huarong and Zou, Zhensheng and Zhang, Haoye and Hu, Shengding and Zheng, Zhi and Zhou, Jie and Cai, Jie and Han, Xu and Zeng, Guoyang and Li, Dahai and Liu, Zhiyuan and Sun, Maosong},
204
- journal={arXiv preprint arXiv:2408.01800},
205
- year={2024},
206
- url={https://arxiv.org/abs/2408.01800},
207
- }
208
  ```
 
198
  journal={arXiv preprint arXiv:2403.11703},
199
  year={2024}
200
  }
 
 
 
 
 
 
 
201
  ```
modeling_minicpmv.py CHANGED
@@ -1,4 +1,5 @@
1
  import math
 
2
  import json
3
  import timm
4
  import torch
@@ -7,13 +8,10 @@ from PIL import Image
7
  from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
8
  from torchvision import transforms
9
  from transformers import LlamaTokenizer
10
- from transformers.integrations import is_deepspeed_zero3_enabled
11
  from .configuration_minicpm import MiniCPMVConfig
12
  from .modeling_minicpm import MiniCPMForCausalLM, MiniCPMPreTrainedModel
13
  from .resampler import Resampler
14
- from functools import partial
15
- from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
16
- from peft.utils.other import ModulesToSaveWrapper
17
 
18
 
19
  class MiniCPMVPreTrainedModel(MiniCPMPreTrainedModel):
@@ -74,29 +72,17 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
74
  def set_input_embeddings(self, value):
75
  self.llm.embed_tokens = value
76
 
77
- def vpm_forward_features(self, pixel_value):
78
- if isinstance(self.vpm, ModulesToSaveWrapper):
79
- if self.vpm.disable_adapters or (self.vpm.active_adapter not in self.vpm.modules_to_save):
80
- return self.vpm.original_module.forward_features(pixel_value)
81
- return self.vpm.modules_to_save[self.vpm.active_adapter].forward_features(pixel_value)
82
- else:
83
- return self.vpm.forward_features(pixel_value)
84
-
85
  def get_vision_embedding(self, pixel_values):
86
  res = []
87
- dtype = self.llm.lm_head.weight.dtype
88
- def process_each_pixel(pixel_value, dtype, config, vpm, resampler):
89
- H, W = pixel_value.shape[-2:]
90
- target_size = (math.ceil(H / config.patch_size), math.ceil(W / config.patch_size))
91
- vision_embedding = self.vpm_forward_features(pixel_value.unsqueeze(0).type(dtype))
92
-
93
- if hasattr(vpm, 'num_prefix_tokens') and vpm.num_prefix_tokens > 0:
94
- vision_embedding = vision_embedding[:, vpm.num_prefix_tokens:]
95
- return resampler(vision_embedding, target_size)
96
-
97
  for pixel_value in pixel_values:
98
- result = process_each_pixel(pixel_value, dtype, self.config, self.vpm, self.resampler)
99
- res.append(result)
 
 
 
 
 
100
  return torch.vstack(res)
101
 
102
  def get_vllm_embedding(self, data):
@@ -107,8 +93,8 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
107
  if len(pixel_values) > 0:
108
  vision_hidden_states.append(self.get_vision_embedding(pixel_values))
109
  elif self.training:
110
- dtype = self.llm.lm_head.weight.dtype
111
- device = self.llm.lm_head.weight.device
112
  dummy_image = torch.zeros(
113
  (1, 3, 224, 224), device=device, dtype=dtype
114
  )
@@ -333,21 +319,24 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
333
  content = msg["content"]
334
  assert role in ["user", "assistant"]
335
  if i == 0:
336
- assert role == "user", "The role of first msg should be user"
337
- if self.config.slice_mode:
338
- images, final_placeholder = self.get_slice_image_placeholder(
339
- image, tokenizer
340
- )
341
- content = final_placeholder + "\n" + content
342
  else:
343
- images = [image]
344
- content = (
345
- tokenizer.im_start
346
- + tokenizer.unk_token * self.config.query_num
347
- + tokenizer.im_end
348
- + "\n"
349
- + content
350
- )
 
 
 
 
 
 
 
351
  prompt += "<用户>" if role == "user" else "<AI>"
352
  prompt += content
353
  prompt += "<AI>"
@@ -388,8 +377,6 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
388
 
389
  return answer, context, generation_config
390
 
391
-
392
-
393
 
394
  class LlamaTokenizerWrapper(LlamaTokenizer):
395
  def __init__(self, **kwargs):
 
1
  import math
2
+ from typing import List, Optional
3
  import json
4
  import timm
5
  import torch
 
8
  from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
9
  from torchvision import transforms
10
  from transformers import LlamaTokenizer
11
+
12
  from .configuration_minicpm import MiniCPMVConfig
13
  from .modeling_minicpm import MiniCPMForCausalLM, MiniCPMPreTrainedModel
14
  from .resampler import Resampler
 
 
 
15
 
16
 
17
  class MiniCPMVPreTrainedModel(MiniCPMPreTrainedModel):
 
72
  def set_input_embeddings(self, value):
73
  self.llm.embed_tokens = value
74
 
 
 
 
 
 
 
 
 
75
  def get_vision_embedding(self, pixel_values):
76
  res = []
77
+ dtype = self.vpm.pos_embed.data.dtype
 
 
 
 
 
 
 
 
 
78
  for pixel_value in pixel_values:
79
+ H, W = pixel_value.shape[-2:]
80
+ tgt_size = (
81
+ math.ceil(H / self.vpm.patch_embed.patch_size[0]), math.ceil(W / self.vpm.patch_embed.patch_size[0]))
82
+ vision_embedding = self.vpm.forward_features(pixel_value.unsqueeze(0).type(dtype))
83
+ if hasattr(self.vpm, 'num_prefix_tokens') and self.vpm.num_prefix_tokens > 0:
84
+ vision_embedding = vision_embedding[:, self.vpm.num_prefix_tokens:]
85
+ res.append(self.resampler(vision_embedding, tgt_size))
86
  return torch.vstack(res)
87
 
88
  def get_vllm_embedding(self, data):
 
93
  if len(pixel_values) > 0:
94
  vision_hidden_states.append(self.get_vision_embedding(pixel_values))
95
  elif self.training:
96
+ dtype = self.vpm.pos_embed.data.dtype
97
+ device = self.vpm.pos_embed.data.device
98
  dummy_image = torch.zeros(
99
  (1, 3, 224, 224), device=device, dtype=dtype
100
  )
 
319
  content = msg["content"]
320
  assert role in ["user", "assistant"]
321
  if i == 0:
322
+ if image is None:
323
+ images = []
 
 
 
 
324
  else:
325
+ assert role == "user", "The role of first msg should be user"
326
+ if self.config.slice_mode:
327
+ images, final_placeholder = self.get_slice_image_placeholder(
328
+ image, tokenizer
329
+ )
330
+ content = final_placeholder + "\n" + content
331
+ else:
332
+ images = [image]
333
+ content = (
334
+ tokenizer.im_start
335
+ + tokenizer.unk_token * self.config.query_num
336
+ + tokenizer.im_end
337
+ + "\n"
338
+ + content
339
+ )
340
  prompt += "<用户>" if role == "user" else "<AI>"
341
  prompt += content
342
  prompt += "<AI>"
 
377
 
378
  return answer, context, generation_config
379
 
 
 
380
 
381
  class LlamaTokenizerWrapper(LlamaTokenizer):
382
  def __init__(self, **kwargs):
resampler.py CHANGED
@@ -19,20 +19,6 @@ from torch.nn.init import trunc_normal_
19
  from torchvision import transforms
20
  from torchvision.transforms import InterpolationMode
21
 
22
- from functools import partial
23
- import numpy as np
24
- import warnings
25
- from typing import Optional, Tuple
26
- import torch
27
- from torch import nn
28
- from torch import Tensor
29
- import torch.nn.functional as F
30
- from torch.nn.functional import *
31
- from torch.nn.modules.activation import *
32
- from torch.nn.init import trunc_normal_
33
- from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
34
- from transformers import PreTrainedModel
35
- from transformers.integrations import is_deepspeed_zero3_enabled
36
  def get_abs_pos(abs_pos, tgt_size):
37
  # abs_pos: L, C
38
  # tgt_size: (H, W)
@@ -131,20 +117,24 @@ class Resampler(nn.Module):
131
  self.pos_embed = nn.Parameter(
132
  torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float()
133
  ).requires_grad_(False)
 
134
  self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
 
135
 
136
  if kv_dim is not None and kv_dim != embed_dim:
137
  self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
138
  else:
139
  self.kv_proj = nn.Identity()
140
 
141
- self.attn = MultiheadAttention(embed_dim, num_heads)
142
  self.ln_q = norm_layer(embed_dim)
143
  self.ln_kv = norm_layer(embed_dim)
144
 
145
  self.ln_post = norm_layer(embed_dim)
146
  self.proj = nn.Parameter((embed_dim ** -0.5) * torch.randn(embed_dim, embed_dim))
147
 
 
 
148
  def _init_weights(self, m):
149
  if isinstance(m, nn.Linear):
150
  trunc_normal_(m.weight, std=.02)
@@ -159,667 +149,22 @@ class Resampler(nn.Module):
159
  pos_embed = torch.Tensor(get_2d_sincos_pos_embed(self.embed_dim, tgt_size)).float().to(device=x.device, dtype=x.dtype)
160
  else:
161
  pos_embed = get_abs_pos(self.pos_embed, tgt_size)
162
-
163
  x = self.kv_proj(x)
164
  x = self.ln_kv(x).permute(1, 0, 2)
165
 
166
  N = x.shape[1]
167
  q = self.ln_q(self.query)
168
-
169
  out = self.attn(
170
  self._repeat(q, N) + self.pos_embed.unsqueeze(1),
171
  x + pos_embed.unsqueeze(1),
172
  x,
173
  attn_mask=attn_mask)[0]
174
  x = out.permute(1, 0, 2)
 
175
  x = self.ln_post(x)
176
  x = x @ self.proj
177
  return x
178
 
179
  def _repeat(self, query, N: int):
180
  return query.unsqueeze(1).repeat(1, N, 1)
181
-
182
-
183
-
184
- class MultiheadAttention(nn.MultiheadAttention):
185
- def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False,
186
- add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None):
187
- super().__init__(embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim, batch_first, device, dtype)
188
-
189
- # rewrite out_proj layer,with nn.Linear
190
- self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias,)
191
-
192
- def forward(
193
- self,
194
- query: Tensor,
195
- key: Tensor,
196
- value: Tensor,
197
- key_padding_mask: Optional[Tensor] = None,
198
- need_weights: bool = True,
199
- attn_mask: Optional[Tensor] = None,
200
- average_attn_weights: bool = True,
201
- is_causal : bool = False) -> Tuple[Tensor, Optional[Tensor]]:
202
- why_not_fast_path = ''
203
- if ((attn_mask is not None and torch.is_floating_point(attn_mask))
204
- or (key_padding_mask is not None) and torch.is_floating_point(key_padding_mask)):
205
- why_not_fast_path = "floating-point masks are not supported for fast path."
206
-
207
- is_batched = query.dim() == 3
208
-
209
- key_padding_mask = F._canonical_mask(
210
- mask=key_padding_mask,
211
- mask_name="key_padding_mask",
212
- other_type=F._none_or_dtype(attn_mask),
213
- other_name="attn_mask",
214
- target_type=query.dtype
215
- )
216
- # _canonical_mask
217
- attn_mask = F._canonical_mask(
218
- mask=attn_mask,
219
- mask_name="attn_mask",
220
- other_type=None,
221
- other_name="",
222
- target_type=query.dtype,
223
- check_other=False,
224
- )
225
-
226
-
227
- if not is_batched:
228
- why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
229
- elif query is not key or key is not value:
230
- # When lifting this restriction, don't forget to either
231
- # enforce that the dtypes all match or test cases where
232
- # they don't!
233
- why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
234
- elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
235
- why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
236
- elif self.in_proj_weight is None:
237
- why_not_fast_path = "in_proj_weight was None"
238
- elif query.dtype != self.in_proj_weight.dtype:
239
- # this case will fail anyway, but at least they'll get a useful error message.
240
- why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
241
- elif self.training:
242
- why_not_fast_path = "training is enabled"
243
- elif (self.num_heads % 2) != 0:
244
- why_not_fast_path = "self.num_heads is not even"
245
- elif not self.batch_first:
246
- why_not_fast_path = "batch_first was not True"
247
- elif self.bias_k is not None:
248
- why_not_fast_path = "self.bias_k was not None"
249
- elif self.bias_v is not None:
250
- why_not_fast_path = "self.bias_v was not None"
251
- elif self.add_zero_attn:
252
- why_not_fast_path = "add_zero_attn was enabled"
253
- elif not self._qkv_same_embed_dim:
254
- why_not_fast_path = "_qkv_same_embed_dim was not True"
255
- elif query.is_nested and (key_padding_mask is not None or attn_mask is not None):
256
- why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \
257
- is not supported with NestedTensor input"
258
- elif torch.is_autocast_enabled():
259
- why_not_fast_path = "autocast is enabled"
260
-
261
- if not why_not_fast_path:
262
- tensor_args = (
263
- query,
264
- key,
265
- value,
266
- self.in_proj_weight,
267
- self.in_proj_bias,
268
- self.out_proj.weight,
269
- self.out_proj.bias,
270
- )
271
- # We have to use list comprehensions below because TorchScript does not support
272
- # generator expressions.
273
- if torch.overrides.has_torch_function(tensor_args):
274
- why_not_fast_path = "some Tensor argument has_torch_function"
275
- elif _is_make_fx_tracing():
276
- why_not_fast_path = "we are running make_fx tracing"
277
- elif not all(_check_arg_device(x) for x in tensor_args):
278
- why_not_fast_path = ("some Tensor argument's device is neither one of "
279
- f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}")
280
- elif torch.is_grad_enabled() and any(_arg_requires_grad(x) for x in tensor_args):
281
- why_not_fast_path = ("grad is enabled and at least one of query or the "
282
- "input/output projection weights or biases requires_grad")
283
- if not why_not_fast_path:
284
- merged_mask, mask_type = self.merge_masks(attn_mask, key_padding_mask, query)
285
-
286
- if self.in_proj_bias is not None and self.in_proj_weight is not None:
287
- return torch._native_multi_head_attention(
288
- query,
289
- key,
290
- value,
291
- self.embed_dim,
292
- self.num_heads,
293
- self.in_proj_weight,
294
- self.in_proj_bias,
295
- self.out_proj.weight,
296
- self.out_proj.bias,
297
- merged_mask,
298
- need_weights,
299
- average_attn_weights,
300
- mask_type)
301
-
302
- any_nested = query.is_nested or key.is_nested or value.is_nested
303
- assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " +
304
- f"The fast path was not hit because {why_not_fast_path}")
305
-
306
- if self.batch_first and is_batched:
307
- # make sure that the transpose op does not affect the "is" property
308
- if key is value:
309
- if query is key:
310
- query = key = value = query.transpose(1, 0)
311
- else:
312
- query, key = (x.transpose(1, 0) for x in (query, key))
313
- value = key
314
- else:
315
- query, key, value = (x.transpose(1, 0) for x in (query, key, value))
316
-
317
- if not self._qkv_same_embed_dim:
318
- attn_output, attn_output_weights = self.multi_head_attention_forward(
319
- query, key, value, self.embed_dim, self.num_heads,
320
- self.in_proj_weight, self.in_proj_bias,
321
- self.bias_k, self.bias_v, self.add_zero_attn,
322
- self.dropout, self.out_proj.weight, self.out_proj.bias,
323
- training=self.training,
324
- key_padding_mask=key_padding_mask, need_weights=need_weights,
325
- attn_mask=attn_mask,
326
- use_separate_proj_weight=True,
327
- q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
328
- v_proj_weight=self.v_proj_weight,
329
- average_attn_weights=average_attn_weights,
330
- is_causal=is_causal)
331
- else:
332
- attn_output, attn_output_weights = self.multi_head_attention_forward(
333
- query, key, value, self.embed_dim, self.num_heads,
334
- self.in_proj_weight, self.in_proj_bias,
335
- self.bias_k, self.bias_v, self.add_zero_attn,
336
- self.dropout, self.out_proj.weight, self.out_proj.bias,
337
- training=self.training,
338
- key_padding_mask=key_padding_mask,
339
- need_weights=need_weights,
340
- attn_mask=attn_mask,
341
- average_attn_weights=average_attn_weights,
342
- is_causal=is_causal)
343
- if self.batch_first and is_batched:
344
- return attn_output.transpose(1, 0), attn_output_weights
345
- else:
346
- return attn_output, attn_output_weights
347
-
348
- def multi_head_attention_forward(
349
- self,
350
- query: Tensor,
351
- key: Tensor,
352
- value: Tensor,
353
- embed_dim_to_check: int,
354
- num_heads: int,
355
- in_proj_weight: Optional[Tensor],
356
- in_proj_bias: Optional[Tensor],
357
- bias_k: Optional[Tensor],
358
- bias_v: Optional[Tensor],
359
- add_zero_attn: bool,
360
- dropout_p: float,
361
- out_proj_weight: Tensor,
362
- out_proj_bias: Optional[Tensor],
363
- training: bool = True,
364
- key_padding_mask: Optional[Tensor] = None,
365
- need_weights: bool = True,
366
- attn_mask: Optional[Tensor] = None,
367
- use_separate_proj_weight: bool = False,
368
- q_proj_weight: Optional[Tensor] = None,
369
- k_proj_weight: Optional[Tensor] = None,
370
- v_proj_weight: Optional[Tensor] = None,
371
- static_k: Optional[Tensor] = None,
372
- static_v: Optional[Tensor] = None,
373
- average_attn_weights: bool = True,
374
- is_causal: bool = False,
375
- ) -> Tuple[Tensor, Optional[Tensor]]:
376
- tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
377
- if has_torch_function(tens_ops):
378
- return handle_torch_function(
379
- multi_head_attention_forward,
380
- tens_ops,
381
- query,
382
- key,
383
- value,
384
- embed_dim_to_check,
385
- num_heads,
386
- in_proj_weight,
387
- in_proj_bias,
388
- bias_k,
389
- bias_v,
390
- add_zero_attn,
391
- dropout_p,
392
- out_proj_weight,
393
- out_proj_bias,
394
- training=training,
395
- key_padding_mask=key_padding_mask,
396
- need_weights=need_weights,
397
- attn_mask=attn_mask,
398
- is_causal=is_causal,
399
- use_separate_proj_weight=use_separate_proj_weight,
400
- q_proj_weight=q_proj_weight,
401
- k_proj_weight=k_proj_weight,
402
- v_proj_weight=v_proj_weight,
403
- static_k=static_k,
404
- static_v=static_v,
405
- average_attn_weights=average_attn_weights,
406
- )
407
-
408
- is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
409
-
410
- # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
411
- # is batched, run the computation and before returning squeeze the
412
- # batch dimension so that the output doesn't carry this temporary batch dimension.
413
- if not is_batched:
414
- # unsqueeze if the input is unbatched
415
- query = query.unsqueeze(1)
416
- key = key.unsqueeze(1)
417
- value = value.unsqueeze(1)
418
- if key_padding_mask is not None:
419
- key_padding_mask = key_padding_mask.unsqueeze(0)
420
-
421
- # set up shape vars
422
- tgt_len, bsz, embed_dim = query.shape
423
- src_len, _, _ = key.shape
424
-
425
- key_padding_mask = _canonical_mask(
426
- mask=key_padding_mask,
427
- mask_name="key_padding_mask",
428
- other_type=_none_or_dtype(attn_mask),
429
- other_name="attn_mask",
430
- target_type=query.dtype
431
- )
432
-
433
- if is_causal and attn_mask is None:
434
- raise RuntimeError(
435
- "Need attn_mask if specifying the is_causal hint. "
436
- "You may use the Transformer module method "
437
- "`generate_square_subsequent_mask` to create this mask."
438
- )
439
-
440
- if is_causal and key_padding_mask is None and not need_weights:
441
- # when we have a kpm or need weights, we need attn_mask
442
- # Otherwise, we use the is_causal hint go as is_causal
443
- # indicator to SDPA.
444
- attn_mask = None
445
- else:
446
- attn_mask = _canonical_mask(
447
- mask=attn_mask,
448
- mask_name="attn_mask",
449
- other_type=None,
450
- other_name="",
451
- target_type=query.dtype,
452
- check_other=False,
453
- )
454
-
455
- if key_padding_mask is not None:
456
- # We have the attn_mask, and use that to merge kpm into it.
457
- # Turn off use of is_causal hint, as the merged mask is no
458
- # longer causal.
459
- is_causal = False
460
-
461
- assert embed_dim == embed_dim_to_check, \
462
- f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
463
- if isinstance(embed_dim, torch.Tensor):
464
- # embed_dim can be a tensor when JIT tracing
465
- head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
466
- else:
467
- head_dim = embed_dim // num_heads
468
- assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
469
- if use_separate_proj_weight:
470
- # allow MHA to have different embedding dimensions when separate projection weights are used
471
- assert key.shape[:2] == value.shape[:2], \
472
- f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
473
- else:
474
- assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
475
-
476
- #
477
- # compute in-projection
478
- #
479
-
480
- if not use_separate_proj_weight:
481
- assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
482
- q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
483
- else:
484
- assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
485
- assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
486
- assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
487
- if in_proj_bias is None:
488
- b_q = b_k = b_v = None
489
- else:
490
- b_q, b_k, b_v = in_proj_bias.chunk(3)
491
- q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)
492
-
493
- # prep attention mask
494
-
495
- if attn_mask is not None:
496
- # ensure attn_mask's dim is 3
497
- if attn_mask.dim() == 2:
498
- correct_2d_size = (tgt_len, src_len)
499
- if attn_mask.shape != correct_2d_size:
500
- raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
501
- attn_mask = attn_mask.unsqueeze(0)
502
- elif attn_mask.dim() == 3:
503
- correct_3d_size = (bsz * num_heads, tgt_len, src_len)
504
- if attn_mask.shape != correct_3d_size:
505
- raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
506
- else:
507
- raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
508
-
509
- # add bias along batch dimension (currently second)
510
- if bias_k is not None and bias_v is not None:
511
- assert static_k is None, "bias cannot be added to static key."
512
- assert static_v is None, "bias cannot be added to static value."
513
- k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
514
- v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
515
- if attn_mask is not None:
516
- attn_mask = pad(attn_mask, (0, 1))
517
- if key_padding_mask is not None:
518
- key_padding_mask = pad(key_padding_mask, (0, 1))
519
- else:
520
- assert bias_k is None
521
- assert bias_v is None
522
-
523
- #
524
- # reshape q, k, v for multihead attention and make em batch first
525
- #
526
- q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
527
- if static_k is None:
528
- k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
529
- else:
530
- # TODO finish disentangling control flow so we don't do in-projections when statics are passed
531
- assert static_k.size(0) == bsz * num_heads, \
532
- f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
533
- assert static_k.size(2) == head_dim, \
534
- f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
535
- k = static_k
536
- if static_v is None:
537
- v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
538
- else:
539
- # TODO finish disentangling control flow so we don't do in-projections when statics are passed
540
- assert static_v.size(0) == bsz * num_heads, \
541
- f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
542
- assert static_v.size(2) == head_dim, \
543
- f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
544
- v = static_v
545
-
546
- # add zero attention along batch dimension (now first)
547
- if add_zero_attn:
548
- zero_attn_shape = (bsz * num_heads, 1, head_dim)
549
- k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
550
- v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
551
- if attn_mask is not None:
552
- attn_mask = pad(attn_mask, (0, 1))
553
- if key_padding_mask is not None:
554
- key_padding_mask = pad(key_padding_mask, (0, 1))
555
-
556
- # update source sequence length after adjustments
557
- src_len = k.size(1)
558
-
559
- # merge key padding and attention masks
560
- if key_padding_mask is not None:
561
- assert key_padding_mask.shape == (bsz, src_len), \
562
- f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
563
- key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
564
- expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
565
- if attn_mask is None:
566
- attn_mask = key_padding_mask
567
- else:
568
- attn_mask = attn_mask + key_padding_mask
569
-
570
- # adjust dropout probability
571
- if not training:
572
- dropout_p = 0.0
573
-
574
- #
575
- # (deep breath) calculate attention and out projection
576
- #
577
-
578
- if need_weights:
579
- B, Nt, E = q.shape
580
- q_scaled = q / math.sqrt(E)
581
-
582
- assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
583
-
584
- if attn_mask is not None:
585
- attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
586
- else:
587
- attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
588
- attn_output_weights = softmax(attn_output_weights, dim=-1)
589
- if dropout_p > 0.0:
590
- attn_output_weights = dropout(attn_output_weights, p=dropout_p)
591
-
592
- attn_output = torch.bmm(attn_output_weights, v)
593
-
594
- attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
595
- attn_output = self.out_proj(attn_output)
596
-
597
- attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
598
-
599
- # optionally average attention weights over heads
600
- attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
601
- if average_attn_weights:
602
- attn_output_weights = attn_output_weights.mean(dim=1)
603
-
604
- if not is_batched:
605
- # squeeze the output if input was unbatched
606
- attn_output = attn_output.squeeze(1)
607
- attn_output_weights = attn_output_weights.squeeze(0)
608
- return attn_output, attn_output_weights
609
- else:
610
- # attn_mask can be either (L,S) or (N*num_heads, L, S)
611
- # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
612
- # in order to match the input for SDPA of (N, num_heads, L, S)
613
- if attn_mask is not None:
614
- if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
615
- attn_mask = attn_mask.unsqueeze(0)
616
- else:
617
- attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
618
-
619
- q = q.view(bsz, num_heads, tgt_len, head_dim)
620
- k = k.view(bsz, num_heads, src_len, head_dim)
621
- v = v.view(bsz, num_heads, src_len, head_dim)
622
-
623
- attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
624
- attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
625
-
626
- attn_output = self.out_proj(attn_output)
627
-
628
- attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
629
- if not is_batched:
630
- # squeeze the output if input was unbatched
631
- attn_output = attn_output.squeeze(1)
632
- return attn_output, None
633
-
634
-
635
- def _mha_shape_check(query: Tensor, key: Tensor, value: Tensor,
636
- key_padding_mask: Optional[Tensor], attn_mask: Optional[Tensor], num_heads: int):
637
- # Verifies the expected shape for `query, `key`, `value`, `key_padding_mask` and `attn_mask`
638
- # and returns if the input is batched or not.
639
- # Raises an error if `query` is not 2-D (unbatched) or 3-D (batched) tensor.
640
-
641
- # Shape check.
642
- if query.dim() == 3:
643
- # Batched Inputs
644
- is_batched = True
645
- assert key.dim() == 3 and value.dim() == 3, \
646
- ("For batched (3-D) `query`, expected `key` and `value` to be 3-D"
647
- f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
648
- if key_padding_mask is not None:
649
- assert key_padding_mask.dim() == 2, \
650
- ("For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D"
651
- f" but found {key_padding_mask.dim()}-D tensor instead")
652
- if attn_mask is not None:
653
- assert attn_mask.dim() in (2, 3), \
654
- ("For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
655
- f" but found {attn_mask.dim()}-D tensor instead")
656
- elif query.dim() == 2:
657
- # Unbatched Inputs
658
- is_batched = False
659
- assert key.dim() == 2 and value.dim() == 2, \
660
- ("For unbatched (2-D) `query`, expected `key` and `value` to be 2-D"
661
- f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
662
-
663
- if key_padding_mask is not None:
664
- assert key_padding_mask.dim() == 1, \
665
- ("For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D"
666
- f" but found {key_padding_mask.dim()}-D tensor instead")
667
-
668
- if attn_mask is not None:
669
- assert attn_mask.dim() in (2, 3), \
670
- ("For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
671
- f" but found {attn_mask.dim()}-D tensor instead")
672
- if attn_mask.dim() == 3:
673
- expected_shape = (num_heads, query.shape[0], key.shape[0])
674
- assert attn_mask.shape == expected_shape, \
675
- (f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}")
676
- else:
677
- raise AssertionError(
678
- f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor")
679
-
680
- return is_batched
681
-
682
-
683
- def _canonical_mask(
684
- mask: Optional[Tensor],
685
- mask_name: str,
686
- other_type: Optional[DType],
687
- other_name: str,
688
- target_type: DType,
689
- check_other: bool = True,
690
- ) -> Optional[Tensor]:
691
-
692
- if mask is not None:
693
- _mask_dtype = mask.dtype
694
- _mask_is_float = torch.is_floating_point(mask)
695
- if _mask_dtype != torch.bool and not _mask_is_float:
696
- raise AssertionError(
697
- f"only bool and floating types of {mask_name} are supported")
698
- if check_other and other_type is not None:
699
- if _mask_dtype != other_type:
700
- warnings.warn(
701
- f"Support for mismatched {mask_name} and {other_name} "
702
- "is deprecated. Use same type for both instead."
703
- )
704
- if not _mask_is_float:
705
- mask = (
706
- torch.zeros_like(mask, dtype=target_type)
707
- .masked_fill_(mask, float("-inf"))
708
- )
709
- return mask
710
-
711
-
712
- def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]:
713
- if input is None:
714
- return None
715
- elif isinstance(input, torch.Tensor):
716
- return input.dtype
717
- raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")
718
-
719
- def _in_projection_packed(
720
- q: Tensor,
721
- k: Tensor,
722
- v: Tensor,
723
- w: Tensor,
724
- b: Optional[Tensor] = None,
725
- ) -> List[Tensor]:
726
- r"""
727
- Performs the in-projection step of the attention operation, using packed weights.
728
- Output is a triple containing projection tensors for query, key and value.
729
- Args:
730
- q, k, v: query, key and value tensors to be projected. For self-attention,
731
- these are typically the same tensor; for encoder-decoder attention,
732
- k and v are typically the same tensor. (We take advantage of these
733
- identities for performance if they are present.) Regardless, q, k and v
734
- must share a common embedding dimension; otherwise their shapes may vary.
735
- w: projection weights for q, k and v, packed into a single tensor. Weights
736
- are packed along dimension 0, in q, k, v order.
737
- b: optional projection biases for q, k and v, packed into a single tensor
738
- in q, k, v order.
739
- Shape:
740
- Inputs:
741
- - q: :math:`(..., E)` where E is the embedding dimension
742
- - k: :math:`(..., E)` where E is the embedding dimension
743
- - v: :math:`(..., E)` where E is the embedding dimension
744
- - w: :math:`(E * 3, E)` where E is the embedding dimension
745
- - b: :math:`E * 3` where E is the embedding dimension
746
- Output:
747
- - in output list :math:`[q', k', v']`, each output tensor will have the
748
- same shape as the corresponding input tensor.
749
- """
750
- E = q.size(-1)
751
- if k is v:
752
- if q is k:
753
- # self-attention
754
- proj = linear(q, w, b)
755
- # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
756
- proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
757
- return proj[0], proj[1], proj[2]
758
- else:
759
- # encoder-decoder attention
760
- w_q, w_kv = w.split([E, E * 2])
761
- if b is None:
762
- b_q = b_kv = None
763
- else:
764
- b_q, b_kv = b.split([E, E * 2])
765
- q_proj = linear(q, w_q, b_q)
766
- kv_proj = linear(k, w_kv, b_kv)
767
- # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk()
768
- kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
769
- return (q_proj, kv_proj[0], kv_proj[1])
770
- else:
771
- w_q, w_k, w_v = w.chunk(3)
772
- if b is None:
773
- b_q = b_k = b_v = None
774
- else:
775
- b_q, b_k, b_v = b.chunk(3)
776
- return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
777
-
778
-
779
- def _in_projection(
780
- q: Tensor,
781
- k: Tensor,
782
- v: Tensor,
783
- w_q: Tensor,
784
- w_k: Tensor,
785
- w_v: Tensor,
786
- b_q: Optional[Tensor] = None,
787
- b_k: Optional[Tensor] = None,
788
- b_v: Optional[Tensor] = None,
789
- ) -> Tuple[Tensor, Tensor, Tensor]:
790
- r"""
791
- Performs the in-projection step of the attention operation. This is simply
792
- a triple of linear projections, with shape constraints on the weights which
793
- ensure embedding dimension uniformity in the projected outputs.
794
- Output is a triple containing projection tensors for query, key and value.
795
- Args:
796
- q, k, v: query, key and value tensors to be projected.
797
- w_q, w_k, w_v: weights for q, k and v, respectively.
798
- b_q, b_k, b_v: optional biases for q, k and v, respectively.
799
- Shape:
800
- Inputs:
801
- - q: :math:`(Qdims..., Eq)` where Eq is the query embedding dimension and Qdims are any
802
- number of leading dimensions.
803
- - k: :math:`(Kdims..., Ek)` where Ek is the key embedding dimension and Kdims are any
804
- number of leading dimensions.
805
- - v: :math:`(Vdims..., Ev)` where Ev is the value embedding dimension and Vdims are any
806
- number of leading dimensions.
807
- - w_q: :math:`(Eq, Eq)`
808
- - w_k: :math:`(Eq, Ek)`
809
- - w_v: :math:`(Eq, Ev)`
810
- - b_q: :math:`(Eq)`
811
- - b_k: :math:`(Eq)`
812
- - b_v: :math:`(Eq)`
813
- Output: in output triple :math:`(q', k', v')`,
814
- - q': :math:`[Qdims..., Eq]`
815
- - k': :math:`[Kdims..., Eq]`
816
- - v': :math:`[Vdims..., Eq]`
817
- """
818
- Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1)
819
- assert w_q.shape == (Eq, Eq), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
820
- assert w_k.shape == (Eq, Ek), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
821
- assert w_v.shape == (Eq, Ev), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
822
- assert b_q is None or b_q.shape == (Eq,), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
823
- assert b_k is None or b_k.shape == (Eq,), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
824
- assert b_v is None or b_v.shape == (Eq,), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
825
- return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
 
19
  from torchvision import transforms
20
  from torchvision.transforms import InterpolationMode
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  def get_abs_pos(abs_pos, tgt_size):
23
  # abs_pos: L, C
24
  # tgt_size: (H, W)
 
117
  self.pos_embed = nn.Parameter(
118
  torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float()
119
  ).requires_grad_(False)
120
+
121
  self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
122
+ trunc_normal_(self.query, std=.02)
123
 
124
  if kv_dim is not None and kv_dim != embed_dim:
125
  self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
126
  else:
127
  self.kv_proj = nn.Identity()
128
 
129
+ self.attn = nn.MultiheadAttention(embed_dim, num_heads)
130
  self.ln_q = norm_layer(embed_dim)
131
  self.ln_kv = norm_layer(embed_dim)
132
 
133
  self.ln_post = norm_layer(embed_dim)
134
  self.proj = nn.Parameter((embed_dim ** -0.5) * torch.randn(embed_dim, embed_dim))
135
 
136
+ self.apply(self._init_weights)
137
+
138
  def _init_weights(self, m):
139
  if isinstance(m, nn.Linear):
140
  trunc_normal_(m.weight, std=.02)
 
149
  pos_embed = torch.Tensor(get_2d_sincos_pos_embed(self.embed_dim, tgt_size)).float().to(device=x.device, dtype=x.dtype)
150
  else:
151
  pos_embed = get_abs_pos(self.pos_embed, tgt_size)
152
+
153
  x = self.kv_proj(x)
154
  x = self.ln_kv(x).permute(1, 0, 2)
155
 
156
  N = x.shape[1]
157
  q = self.ln_q(self.query)
 
158
  out = self.attn(
159
  self._repeat(q, N) + self.pos_embed.unsqueeze(1),
160
  x + pos_embed.unsqueeze(1),
161
  x,
162
  attn_mask=attn_mask)[0]
163
  x = out.permute(1, 0, 2)
164
+
165
  x = self.ln_post(x)
166
  x = x @ self.proj
167
  return x
168
 
169
  def _repeat(self, query, N: int):
170
  return query.unsqueeze(1).repeat(1, N, 1)