Video-Text-to-Text · Safetensors · mistral

ynhe committed (verified) · Commit 1905130 · 1 Parent(s): 0568fba

[Init] upload model

config.json CHANGED
@@ -1,7 +1,6 @@
 {
-  "_name_or_path": "/mnt/petrelfs/wangchenting/multimodalllm/logs/scripts/pt/1b_qformer_mistral/stage3_hd.sh_20240715_211017/checkpoint-last",
   "architectures": [
-    "MultiModalLLM_PT"
+    "MistralModel"
   ],
   "attention_dropout": 0.0,
   "bos_token_id": 1,
@@ -11,8 +10,50 @@
   "initializer_range": 0.02,
   "intermediate_size": 14336,
   "max_position_embeddings": 32768,
-  "model_config": null,
-  "model_tokenizer": null,
+  "model_config": {
+    "bridge": {
+      "extra_num_query_token": 64,
+      "name": "qformer",
+      "num_query_token": 32,
+      "qformer_attention_probs_dropout_prob": 0.1,
+      "qformer_drop_path_rate": 0.2,
+      "qformer_hidden_dropout_prob": 0.1
+    },
+    "freeze_bridge": false,
+    "freeze_llm": false,
+    "freeze_vision_encoder": false,
+    "llm": {
+      "lora_alpha": 32,
+      "lora_dropout": 0.1,
+      "lora_r": 16,
+      "name": "mistral_7b",
+      "pretrained_llm_path": "mistralai/Mistral-7B-Instruct-v0.3",
+      "use_lora": true
+    },
+    "loss": {
+      "use_vision_regression_loss": false
+    },
+    "model_cls": "MultiModalLLM_PT",
+    "pretrained_paths": {},
+    "use_flash_attention": true,
+    "vision_encoder": {
+      "checkpoint_num": 48,
+      "d_model": 1408,
+      "encoder_embed_dim": 1408,
+      "img_size": 224,
+      "name": "internvideo2-1B",
+      "num_frames": 16,
+      "origin_num_frames": 4,
+      "patch_size": 14,
+      "pretrained": null,
+      "sep_image_video_pos_embed": true,
+      "tubelet_size": 1,
+      "use_checkpoint": true,
+      "vit_add_ln": true,
+      "x_vis_only": true,
+      "x_vis_return_idx": -2
+    }
+  },
   "model_type": "mistral",
   "num_attention_heads": 32,
   "num_hidden_layers": 32,
@@ -21,8 +62,8 @@
   "rope_theta": 1000000.0,
   "sliding_window": null,
   "tie_word_embeddings": false,
-  "torch_dtype": "bfloat16",
-  "transformers_version": "4.35.2",
+  "torch_dtype": "float32",
+  "transformers_version": "4.38.0",
   "use_cache": true,
   "vocab_size": 32768
 }
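The new model_config block nests the bridge (qformer), LLM (mistral_7b with LoRA) and vision-encoder (internvideo2-1B) settings inside the standard Mistral config. A minimal sketch of inspecting those nested settings after this commit; the repo id below is a placeholder, and huggingface_hub is assumed to be installed:

import json
from huggingface_hub import hf_hub_download

# "OpenGVLab/placeholder-repo" stands in for this repository's actual id
path = hf_hub_download(repo_id="OpenGVLab/placeholder-repo", filename="config.json")
with open(path) as f:
    cfg = json.load(f)

bridge = cfg["model_config"]["bridge"]          # qformer: 32 query tokens (+64 extra)
llm = cfg["model_config"]["llm"]                # mistral_7b with LoRA (r=16, alpha=32)
vision = cfg["model_config"]["vision_encoder"]  # internvideo2-1B, 16 frames, 224px
print(bridge["num_query_token"], llm["pretrained_llm_path"], vision["name"])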
flash_attention_class.py ADDED
@@ -0,0 +1,71 @@
import torch
import torch.nn as nn

from einops import rearrange

from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input


class FlashAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
            (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
        attention_dropout: The dropout rate to apply to the attention
            (default: 0.0)
    """

    def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
        super().__init__()
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
                max_s=None, need_weights=False):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
                if unpadded: (nnz, 3, h, d)
            key_padding_mask: a bool tensor of shape (B, S)
        """
        assert not need_weights
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda

        if cu_seqlens is None:
            batch_size = qkv.shape[0]
            seqlen = qkv.shape[1]
            if key_padding_mask is None:
                qkv = rearrange(qkv, 'b s ... -> (b s) ...')
                max_s = seqlen
                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                          device=qkv.device)
                output = flash_attn_varlen_qkvpacked_func(
                    qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
                    softmax_scale=self.softmax_scale, causal=causal
                )
                output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
            else:
                nheads = qkv.shape[-2]
                x = rearrange(qkv, 'b s three h d -> b s (three h d)')
                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
                x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
                output_unpad = flash_attn_varlen_qkvpacked_func(
                    x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
                    softmax_scale=self.softmax_scale, causal=causal
                )
                output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
                                             indices, batch_size, seqlen),
                                   'b s (h d) -> b s h d', h=nheads)
        else:
            assert max_s is not None
            output = flash_attn_varlen_qkvpacked_func(
                qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
                softmax_scale=self.softmax_scale, causal=causal
            )

        return output, None
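A small usage sketch of the wrapper above, assuming a CUDA device, the flash_attn package, and that this file is importable locally; the tensor sizes are illustrative only:

import torch
from flash_attention_class import FlashAttention

attn = FlashAttention(attention_dropout=0.0).eval()
B, S, H, D = 2, 256, 16, 88        # batch, sequence length, heads, head dim (illustrative)
qkv = torch.randn(B, S, 3, H, D, device="cuda", dtype=torch.float16)
with torch.no_grad():
    out, _ = attn(qkv)             # packed path: key_padding_mask=None, cu_seqlens=None
print(out.shape)                   # torch.Size([2, 256, 16, 88])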
model.safetensors.index.json CHANGED
The diff for this file is too large to render. See raw diff
 
modeling_base.py ADDED
@@ -0,0 +1,200 @@
import io
import logging
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import MSELoss
from .llm.llama_xformer import LlamaForCausalLM

from petrel_client.client import Client
from torch.cuda.amp import autocast

from .vision_encoder import pretrain_internvideo2_giant_patch14_224_clean, build_vit, interpolate_pos_embed_internvideo2_new
from .bridge import build_qformer, build_causal_qformer

logger = logging.getLogger(__name__)

from transformers import LlamaTokenizer, AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoProcessor
from transformers import AutoConfig, PreTrainedModel


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def freeze_module(module):
    for _, param in module.named_parameters():
        param.requires_grad = False
    module = module.eval()
    module.train = disabled_train
    return module


class LLMConfig(AutoConfig):
    model_type = ""


class BaseMLLM(PreTrainedModel):
    config_class = LLMConfig

    def __init__(self, config):
        # m_config = LLMConfig.from_pretrained('/mnt/petrelfs/share_data/likunchang/model/llm/internlm2-chat-20b', trust_remote_code=True)
        # super().__init__(config)
        self.model_config = config.model_config
        config.model_config = None
        super().__init__(config)
        self.build_vision_encoder()
        self.build_llm()
        self.build_bridge()
        self.build_loss()
        self.load_pretrained_weights()
        # NOTE: keep this after the LLM is (optionally) frozen
        for n, p in self.named_parameters():
            if p.requires_grad:
                logger.info(f'{n} requires_grad')

    def build_vision_encoder(self):
        # load pretrained internvideo2-1b here, simplified as it receives no args
        # note that we haven't loaded the internvideo pretrained version here
        if 'internvideo2' in self.model_config.vision_encoder.name.lower():
            encoder_name = self.model_config.vision_encoder.name
            logger.info(f"Build vision_encoder: {encoder_name}")
            if encoder_name == 'internvideo2-1B':
                self.vision_encoder = pretrain_internvideo2_giant_patch14_224_clean(self.model_config)
            else:
                raise ValueError(f"Not implemented: {encoder_name}")
        elif 'vit' in self.model_config.vision_encoder.name.lower():
            self.vision_encoder = build_vit(self.model_config)
        else:
            raise NotImplementedError(self.model_config.vision_encoder.name)

        if self.model_config.vision_encoder.vit_add_ln:
            self.vision_layernorm = nn.LayerNorm(self.model_config.vision_encoder.encoder_embed_dim, eps=1e-12)
        else:
            self.vision_layernorm = nn.Identity()

        self.freeze_vision_encoder = self.model_config.get("freeze_vision_encoder", False)

        if self.freeze_vision_encoder:
            logger.info("freeze vision encoder")
            freeze_module(self.vision_encoder)
            freeze_module(self.vision_layernorm)

    def build_bridge(self):
        # QFormer to LM: project the 768-dim QFormer output up to the LLM hidden size
        self.project_up = nn.Linear(768, self.lm.config.hidden_size)  # NOTE: unclear whether a bias is needed here
        # LM to QFormer: project the LLM hidden size back down to 768
        self.project_down = nn.Linear(self.lm.config.hidden_size, 768)

        if 'qformer' in self.model_config.bridge.name.lower():
            from transformers import BertTokenizer
            self.qformer_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="left", local_files_only=True)
            self.qformer_tokenizer.add_special_tokens({"bos_token": "[DEC]"})
            self.qformer_tokenizer.padding_side = "left"
            if self.model_config.bridge.name == 'qformer':
                self.qformer, self.query_tokens = build_qformer(
                    self.model_config.bridge.num_query_token, self.model_config.vision_encoder.encoder_embed_dim,
                    qformer_hidden_dropout_prob=self.model_config.bridge.qformer_hidden_dropout_prob,
                    qformer_attention_probs_dropout_prob=self.model_config.bridge.qformer_attention_probs_dropout_prob,
                    qformer_drop_path_rate=self.model_config.bridge.qformer_drop_path_rate,
                )
            elif self.model_config.bridge.name == 'causal_qformer':
                self.qformer, self.query_tokens = build_causal_qformer(
                    self.model_config.bridge.num_query_token, self.model_config.vision_encoder.encoder_embed_dim,
                    qformer_hidden_dropout_prob=self.model_config.bridge.qformer_hidden_dropout_prob,
                    qformer_attention_probs_dropout_prob=self.model_config.bridge.qformer_attention_probs_dropout_prob
                )
            self.qformer.resize_token_embeddings(len(self.qformer_tokenizer))
            self.qformer.cls = None
            self.extra_num_query_token = self.model_config.bridge.extra_num_query_token
            if self.model_config.bridge.extra_num_query_token > 0:
                logger.info(f"Add extra {self.model_config.bridge.extra_num_query_token} tokens in QFormer")
                self.extra_query_tokens = nn.Parameter(
                    torch.zeros(1, self.model_config.bridge.extra_num_query_token, self.query_tokens.shape[-1])
                )

        self.freeze_bridge = self.model_config.get("freeze_bridge", False)
        if self.freeze_bridge:
            logger.info("freeze bridge")
            freeze_module(self.qformer)
            self.query_tokens.requires_grad = False

    def build_llm(self):
        self.lm_name = self.model_config.llm.name
        if self.model_config.llm.name == "vicuna1.5_7b":
            self.lm = LlamaForCausalLM.from_pretrained(self.model_config.llm.pretrained_llm_path)
            self.lm.gradient_checkpointing = self.model_config.llm.get("use_llama_gradient_checkpointing", True)
        elif self.model_config.llm.name == 'mistral_7b':
            from transformers import AutoModelForCausalLM
            self.lm = AutoModelForCausalLM.from_pretrained(
                self.model_config.llm.pretrained_llm_path,
                torch_dtype=torch.bfloat16,
                # attn_implementation="flash_attention_2",
            )
        elif self.model_config.llm.name == 'internlm_20b':
            from transformers import AutoModelForCausalLM
            self.lm = AutoModelForCausalLM.from_pretrained(
                self.model_config.llm.pretrained_llm_path,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
            )
            self.lm.gradient_checkpointing = True
            self.lm._set_gradient_checkpointing()
        elif self.model_config.llm.name == 'internlm2_5_7b':
            from transformers import AutoModelForCausalLM
            self.lm = AutoModelForCausalLM.from_pretrained(
                self.model_config.llm.pretrained_llm_path,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                local_files_only=True,
            )
        else:
            raise NotImplementedError(self.model_config.llm.name)

        self.freeze_llm = self.model_config.get("freeze_llm", True)
        logger.info(f'freeze_llm: {self.freeze_llm}')
        if self.freeze_llm:
            logger.info("freeze llm")
            freeze_module(self.lm)

        if self.model_config.llm.use_lora:
            self.use_lora = True
            from peft import get_peft_model, LoraConfig, TaskType
            logger.info("Use lora")
            if self.model_config.llm.name == 'internlm_20b':
                peft_config = LoraConfig(
                    task_type=TaskType.CAUSAL_LM, inference_mode=False,
                    r=self.model_config.llm.lora_r, lora_alpha=self.model_config.llm.lora_alpha, lora_dropout=self.model_config.llm.lora_dropout,
                    target_modules=['wqkv', 'wo', 'w1', 'w2', 'w3', 'output']
                )
            else:
                peft_config = LoraConfig(
                    task_type=TaskType.CAUSAL_LM, inference_mode=False,
                    r=self.model_config.llm.lora_r, lora_alpha=self.model_config.llm.lora_alpha, lora_dropout=self.model_config.llm.lora_dropout,
                    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                                    "gate_proj", "up_proj", "down_proj", "lm_head"]
                )

            self.lm = get_peft_model(self.lm, peft_config)
            self.lm.enable_input_require_grads()
            self.lm.print_trainable_parameters()
        else:
            self.use_lora = False

    def build_loss(self):
        self.use_vision_regression_loss = self.model_config.loss.get("use_vision_regression_loss", False)
        if self.use_vision_regression_loss:
            self.image_loss_fct = MSELoss()

    @property
    def dtype(self):
        return self.lm.dtype

    @property
    def device(self):
        return self.lm.device
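For reference, a sketch of the LoRA wrapping that build_llm applies in the mistral_7b branch, filled in with the values from this commit's config.json (peft and transformers are assumed to be installed):

import torch
from transformers import AutoModelForCausalLM
from peft import get_peft_model, LoraConfig, TaskType

lm = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.3", torch_dtype=torch.bfloat16
)
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, inference_mode=False,
    r=16, lora_alpha=32, lora_dropout=0.1,   # llm.lora_r / lora_alpha / lora_dropout
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj", "lm_head"],
)
lm = get_peft_model(lm, peft_config)
lm.print_trainable_parameters()              # LoRA adds only a small fraction of trainable weights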
modeling_internvideo2_vit.py ADDED
@@ -0,0 +1,983 @@
1
+ import math
2
+ import logging
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
6
+ from torch import nn
7
+
8
+ import torch.utils.checkpoint as checkpoint
9
+ from functools import partial
10
+ from einops import rearrange
11
+ from .flash_attention_class import FlashAttention
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ try:
16
+ from flash_attn.modules.mlp import FusedMLP
17
+ except:
18
+ logger.warn(f'FusedMLP of flash_attn is not installed!!!')
19
+
20
+ try:
21
+ from flash_attn.ops.rms_norm import DropoutAddRMSNorm
22
+ except:
23
+ logger.warn(f'DropoutAddRMSNorm of flash_attn is not installed!!!')
24
+
25
+ import numpy as np
26
+ import torch
27
+ import logging
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+ # --------------------------------------------------------
32
+ # 3D sine-cosine position embedding
33
+ # References:
34
+ # MVD: https://github.com/ruiwang2021/mvd/blob/main/modeling_finetune.py
35
+ # --------------------------------------------------------
36
+ def get_3d_sincos_pos_embed(embed_dim, grid_size, t_size, cls_token=False):
37
+ """
38
+ grid_size: int of the grid height and width
39
+ t_size: int of the temporal size
40
+ return:
41
+ pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
42
+ """
43
+ assert embed_dim % 4 == 0
44
+ embed_dim_spatial = embed_dim // 4 * 3
45
+ embed_dim_temporal = embed_dim // 4
46
+
47
+ # spatial
48
+ grid_h = np.arange(grid_size, dtype=np.float32)
49
+ grid_w = np.arange(grid_size, dtype=np.float32)
50
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
51
+ grid = np.stack(grid, axis=0)
52
+
53
+ grid = grid.reshape([2, 1, grid_size, grid_size])
54
+ pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(
55
+ embed_dim_spatial, grid
56
+ )
57
+
58
+ # temporal
59
+ grid_t = np.arange(t_size, dtype=np.float32)
60
+ pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(
61
+ embed_dim_temporal, grid_t
62
+ )
63
+
64
+ # concate: [T, H, W] order
65
+ pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
66
+ pos_embed_temporal = np.repeat(
67
+ pos_embed_temporal, grid_size**2, axis=1
68
+ ) # [T, H*W, D // 4]
69
+ pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
70
+ pos_embed_spatial = np.repeat(
71
+ pos_embed_spatial, t_size, axis=0
72
+ ) # [T, H*W, D // 4 * 3]
73
+
74
+ pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
75
+ pos_embed = pos_embed.reshape([-1, embed_dim]) # [T*H*W, D]
76
+
77
+ if cls_token:
78
+ pos_embed = np.concatenate(
79
+ [np.zeros([1, embed_dim]), pos_embed], axis=0
80
+ )
81
+ return pos_embed
82
+
83
+
84
+ # --------------------------------------------------------
85
+ # 2D sine-cosine position embedding
86
+ # References:
87
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
88
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
89
+ # --------------------------------------------------------
90
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
91
+ """
92
+ grid_size: int of the grid height and width
93
+ return:
94
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
95
+ """
96
+ grid_h = np.arange(grid_size, dtype=np.float32)
97
+ grid_w = np.arange(grid_size, dtype=np.float32)
98
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
99
+ grid = np.stack(grid, axis=0)
100
+
101
+ grid = grid.reshape([2, 1, grid_size, grid_size])
102
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
103
+ if cls_token:
104
+ pos_embed = np.concatenate(
105
+ [np.zeros([1, embed_dim]), pos_embed], axis=0
106
+ )
107
+ return pos_embed
108
+
109
+
110
+ def get_1d_sincos_pos_embed(embed_dim, t_size, cls_token=False):
111
+ """
112
+ t_size: int of the temporal size
113
+ return:
114
+ pos_embed: [t_size, embed_dim] or [1+t_size, embed_dim] (w/ or w/o cls_token)
115
+ """
116
+ grid_t = np.arange(t_size, dtype=np.float32)
117
+ pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_t)
118
+ if cls_token:
119
+ pos_embed = np.concatenate(
120
+ [np.zeros([1, embed_dim]), pos_embed], axis=0
121
+ )
122
+ return pos_embed
123
+
124
+
125
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
126
+ assert embed_dim % 2 == 0
127
+
128
+ # use half of dimensions to encode grid_h
129
+ emb_h = get_1d_sincos_pos_embed_from_grid(
130
+ embed_dim // 2, grid[0]
131
+ ) # (H*W, D/2)
132
+ emb_w = get_1d_sincos_pos_embed_from_grid(
133
+ embed_dim // 2, grid[1]
134
+ ) # (H*W, D/2)
135
+
136
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
137
+ return emb
138
+
139
+
140
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
141
+ """
142
+ embed_dim: output dimension for each position
143
+ pos: a list of positions to be encoded: size (M,)
144
+ out: (M, D)
145
+ """
146
+ assert embed_dim % 2 == 0
147
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
148
+ omega /= embed_dim / 2.0
149
+ omega = 1.0 / 10000**omega # (D/2,)
150
+
151
+ pos = pos.reshape(-1) # (M,)
152
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
153
+
154
+ emb_sin = np.sin(out) # (M, D/2)
155
+ emb_cos = np.cos(out) # (M, D/2)
156
+
157
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
158
+ return emb
159
+
160
+
161
+ def interpolate_pos_embed_internvideo2(checkpoint_model, model, orig_t_size = 8):
162
+ # interpolate position embedding
163
+ for pos_name in ['pos_embed', 'clip_pos_embed']:
164
+ if pos_name in checkpoint_model:
165
+ pos_embed_checkpoint = checkpoint_model[pos_name]
166
+ embedding_size = pos_embed_checkpoint.shape[-1] # channel dim
167
+ num_patches = model.patch_embed.num_patches #
168
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches # 0/1
169
+
170
+ # we use 8 frames for pretraining
171
+ # new_t_size = args.num_frames * args.num_segments // model.patch_embed.tubelet_size
172
+ new_t_size = model.num_frames // model.tubelet_size
173
+ # height (== width) for the checkpoint position embedding
174
+ orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens)//(orig_t_size)) ** 0.5)
175
+ # height (== width) for the new position embedding
176
+ new_size = int((num_patches // (new_t_size))** 0.5)
177
+
178
+ # class_token and dist_token are kept unchanged
179
+ if orig_t_size != new_t_size:
180
+ logger.info(f"Temporal interpolate from {orig_t_size} to {new_t_size} ({pos_name})")
181
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
182
+ # only the position tokens are interpolated
183
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
184
+ # B, L, C -> B, T, HW, C -> BHW, C, T (B = 1)
185
+ pos_tokens = pos_tokens.view(1, orig_t_size, -1, embedding_size)
186
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size, orig_t_size)
187
+ pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=new_t_size, mode='linear')
188
+ pos_tokens = pos_tokens.view(1, -1, embedding_size, new_t_size)
189
+ pos_tokens = pos_tokens.permute(0, 3, 1, 2).reshape(1, -1, embedding_size)
190
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
191
+ checkpoint_model[pos_name] = new_pos_embed
192
+ pos_embed_checkpoint = new_pos_embed
193
+
194
+ # class_token and dist_token are kept unchanged
195
+ if orig_size != new_size:
196
+ logger.info(f"Position interpolate from {orig_size}x{orig_size} to {new_size}x{new_size} ({pos_name})")
197
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
198
+ # only the position tokens are interpolated
199
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
200
+ # B, L, C -> BT, H, W, C -> BT, C, H, W
201
+ pos_tokens = pos_tokens.reshape(-1, new_t_size, orig_size, orig_size, embedding_size)
202
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
203
+ pos_tokens = torch.nn.functional.interpolate(
204
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
205
+ # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
206
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, new_t_size, new_size, new_size, embedding_size)
207
+ pos_tokens = pos_tokens.flatten(1, 3) # B, L, C
208
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
209
+ checkpoint_model[pos_name] = new_pos_embed
210
+
211
+
212
+ if 'pos_embed_spatial' in checkpoint_model or 'pos_embed_temporal' in checkpoint_model:
213
+ raise NotImplementedError
214
+
215
+ def interpolate_pos_embed_internvideo2_new(checkpoint_model, model, orig_t_size = 8):
216
+ pos_names = []
217
+ for k in checkpoint_model.keys():
218
+ if ('pos_embed' in k or 'clip_pos_embed' in k) and 'img_pos_embed' not in k: # NOTE do not interpolate img_pos_embed for now; it may need to be added back for higher resolutions
219
+ pos_names.append(k)
220
+
221
+ logger.info(f"pos names list for interpolating: {pos_names}")
222
+
223
+ assert len(pos_names) > 0, checkpoint_model.keys()
224
+
225
+ if 'pos_embed_spatial' in checkpoint_model.keys() or 'pos_embed_temporal' in checkpoint_model.keys():
226
+ raise NotImplementedError
227
+
228
+ # interpolate position embedding
229
+ for pos_name in pos_names:
230
+
231
+ pos_embed_checkpoint = checkpoint_model[pos_name]
232
+ embedding_size = pos_embed_checkpoint.shape[-1] # channel dim
233
+ num_patches = model.patch_embed.num_patches #
234
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches # 0/1
235
+
236
+ # we use 8 frames for pretraining
237
+ # new_t_size = args.num_frames * args.num_segments // model.patch_embed.tubelet_size
238
+ new_t_size = model.num_frames // model.tubelet_size
239
+ # height (== width) for the checkpoint position embedding
240
+ orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens)//(orig_t_size)) ** 0.5)
241
+ # height (== width) for the new position embedding
242
+ new_size = int((num_patches // (new_t_size))** 0.5)
243
+
244
+ # class_token and dist_token are kept unchanged
245
+ if orig_t_size != new_t_size:
246
+ logger.info(f"Temporal interpolate from {orig_t_size} to {new_t_size} ({pos_name})")
247
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
248
+ # only the position tokens are interpolated
249
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
250
+ # B, L, C -> B, T, HW, C -> BHW, C, T (B = 1)
251
+ pos_tokens = pos_tokens.view(1, orig_t_size, -1, embedding_size)
252
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size, orig_t_size)
253
+ pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=new_t_size, mode='linear')
254
+ pos_tokens = pos_tokens.view(1, -1, embedding_size, new_t_size)
255
+ pos_tokens = pos_tokens.permute(0, 3, 1, 2).reshape(1, -1, embedding_size)
256
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
257
+ checkpoint_model[pos_name] = new_pos_embed
258
+ pos_embed_checkpoint = new_pos_embed
259
+
260
+ # class_token and dist_token are kept unchanged
261
+ if orig_size != new_size:
262
+ logger.info(f"Position interpolate from {orig_size}x{orig_size} to {new_size}x{new_size} ({pos_name})")
263
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
264
+ # only the position tokens are interpolated
265
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
266
+ # B, L, C -> BT, H, W, C -> BT, C, H, W
267
+ pos_tokens = pos_tokens.reshape(-1, new_t_size, orig_size, orig_size, embedding_size)
268
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
269
+ pos_tokens = torch.nn.functional.interpolate(
270
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
271
+ # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
272
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, new_t_size, new_size, new_size, embedding_size)
273
+ pos_tokens = pos_tokens.flatten(1, 3) # B, L, C
274
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
275
+ checkpoint_model[pos_name] = new_pos_embed
276
+
277
+
278
+
279
+ def interpolate_pos_embed(checkpoint_model, model, orig_t_size=4, pos_name='vision_encoder.pos_embed'):
280
+ if pos_name in checkpoint_model:
281
+ pos_embed_checkpoint = checkpoint_model[pos_name]
282
+ embedding_size = pos_embed_checkpoint.shape[-1] # channel dim
283
+ num_patches = model.patch_embed.num_patches #
284
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches # 0/1
285
+
286
+ # we use 4 frames for pretraining
287
+ new_t_size = model.T
288
+ # height (== width) for the checkpoint position embedding
289
+ orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens)//(orig_t_size)) ** 0.5)
290
+ # height (== width) for the new position embedding
291
+ new_size = int((num_patches // (new_t_size))** 0.5)
292
+
293
+ # class_token and dist_token are kept unchanged
294
+ if orig_t_size != new_t_size:
295
+ print(f"Temporal interpolate from {orig_t_size} to {new_t_size} ({pos_name})")
296
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
297
+ # only the position tokens are interpolated
298
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
299
+ # B, L, C -> B, T, HW, C -> BHW, C, T (B = 1)
300
+ pos_tokens = pos_tokens.view(1, orig_t_size, -1, embedding_size)
301
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size, orig_t_size)
302
+ pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=new_t_size, mode='linear')
303
+ pos_tokens = pos_tokens.view(1, -1, embedding_size, new_t_size)
304
+ pos_tokens = pos_tokens.permute(0, 3, 1, 2).reshape(1, -1, embedding_size)
305
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
306
+ checkpoint_model[pos_name] = new_pos_embed
307
+ pos_embed_checkpoint = new_pos_embed
308
+
309
+ # class_token and dist_token are kept unchanged
310
+ if orig_size != new_size:
311
+ print(f"Position interpolate from {orig_size}x{orig_size} to {new_size}x{new_size} ({pos_name})")
312
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
313
+ # only the position tokens are interpolated
314
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
315
+ # B, L, C -> BT, H, W, C -> BT, C, H, W
316
+ pos_tokens = pos_tokens.reshape(-1, new_t_size, orig_size, orig_size, embedding_size)
317
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
318
+ pos_tokens = torch.nn.functional.interpolate(
319
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
320
+ # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
321
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, new_t_size, new_size, new_size, embedding_size)
322
+ pos_tokens = pos_tokens.flatten(1, 3) # B, L, C
323
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
324
+ checkpoint_model[pos_name] = new_pos_embed
325
+ else:
326
+ raise NotImplementedError
327
+
328
+
329
+
330
+ class CrossAttention(nn.Module):
331
+ def __init__(
332
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
333
+ proj_drop=0., attn_head_dim=None, out_dim=None):
334
+ super().__init__()
335
+ if out_dim is None:
336
+ out_dim = dim
337
+ self.num_heads = num_heads
338
+ head_dim = dim // num_heads
339
+ if attn_head_dim is not None:
340
+ head_dim = attn_head_dim
341
+ all_head_dim = head_dim * self.num_heads
342
+ self.scale = qk_scale or head_dim ** -0.5
343
+ assert all_head_dim == dim
344
+
345
+ self.q = nn.Linear(dim, all_head_dim, bias=False)
346
+ self.k = nn.Linear(dim, all_head_dim, bias=False)
347
+ self.v = nn.Linear(dim, all_head_dim, bias=False)
348
+
349
+ if qkv_bias:
350
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
351
+ self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
352
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
353
+ else:
354
+ self.q_bias = None
355
+ self.k_bias = None
356
+ self.v_bias = None
357
+
358
+ self.attn_drop = nn.Dropout(attn_drop)
359
+ self.proj = nn.Linear(all_head_dim, out_dim)
360
+ self.proj_drop = nn.Dropout(proj_drop)
361
+
362
+ def forward(self, x, k=None, v=None):
363
+ B, N, C = x.shape
364
+ N_k = k.shape[1]
365
+ N_v = v.shape[1]
366
+
367
+ q_bias, k_bias, v_bias = None, None, None
368
+ if self.q_bias is not None:
369
+ q_bias = self.q_bias
370
+ k_bias = self.k_bias
371
+ v_bias = self.v_bias
372
+
373
+ q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
374
+ q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, N_head, N_q, dim)
375
+
376
+ k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
377
+ k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
378
+
379
+ v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
380
+ v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
381
+
382
+ q = q * self.scale
383
+ attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
384
+
385
+ attn = attn.softmax(dim=-1)
386
+ attn = self.attn_drop(attn)
387
+
388
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
389
+ x = self.proj(x)
390
+ x = self.proj_drop(x)
391
+
392
+ return x
393
+
394
+
395
+ class AttentiveBlock(nn.Module):
396
+
397
+ def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
398
+ drop_path=0., norm_layer=nn.LayerNorm, attn_head_dim=None, out_dim=None):
399
+ super().__init__()
400
+
401
+ self.norm1_q = norm_layer(dim)
402
+ self.norm1_k = norm_layer(dim)
403
+ self.norm1_v = norm_layer(dim)
404
+ self.cross_attn = CrossAttention(
405
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
406
+ proj_drop=drop, attn_head_dim=attn_head_dim, out_dim=out_dim)
407
+
408
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
409
+
410
+ def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None):
411
+ x_q = self.norm1_q(x_q + pos_q)
412
+ x_k = self.norm1_k(x_kv + pos_k)
413
+ x_v = self.norm1_v(x_kv)
414
+ x = self.cross_attn(x_q, k=x_k, v=x_v)
415
+
416
+ return x
417
+
418
+
419
+ class AttentionPoolingBlock(AttentiveBlock):
420
+
421
+ def forward(self, x):
422
+ x_q = x.mean(1, keepdim=True)
423
+ x_kv, pos_q, pos_k = x, 0, 0
424
+ x = super().forward(x_q, x_kv, pos_q, pos_k, bool_masked_pos=None, rel_pos_bias=None)
425
+ x = x.squeeze(1)
426
+ return x
427
+
428
+
429
+ class RMSNorm(nn.Module):
430
+ def __init__(self, hidden_size, eps=1e-6):
431
+ super().__init__()
432
+ self.weight = nn.Parameter(torch.ones(hidden_size))
433
+ self.variance_epsilon = eps
434
+
435
+ def forward(self, hidden_states):
436
+ input_dtype = hidden_states.dtype
437
+ hidden_states = hidden_states.to(torch.float32)
438
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
439
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
440
+ return self.weight * hidden_states.to(input_dtype)
441
+
442
+
443
+ class Attention(nn.Module):
444
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_flash_attn=False,
445
+ causal=False, norm_layer=nn.LayerNorm, qk_normalization=False, use_fused_rmsnorm=False):
446
+ super().__init__()
447
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
448
+ self.num_heads = num_heads
449
+ head_dim = dim // num_heads
450
+ self.scale = head_dim ** -0.5
451
+
452
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
453
+ self.attn_drop = nn.Dropout(attn_drop)
454
+ self.proj = nn.Linear(dim, dim)
455
+ self.proj_drop = nn.Dropout(proj_drop)
456
+
457
+ self.use_flash_attn = use_flash_attn
458
+ if use_flash_attn:
459
+ self.causal = causal
460
+ self.inner_attn = FlashAttention(attention_dropout=attn_drop)
461
+
462
+ self.qk_normalization = qk_normalization
463
+ self.q_norm = norm_layer(dim) if qk_normalization else nn.Identity()
464
+ self.k_norm = norm_layer(dim) if qk_normalization else nn.Identity()
465
+ self.use_fused_rmsnorm = use_fused_rmsnorm
466
+
467
+ def _naive_attn(self, x):
468
+ B, N, C = x.shape
469
+ # print(x.shape, torch.cuda.memory_allocated(), torch.cuda.memory_allocated())
470
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
471
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
472
+
473
+ if self.qk_normalization:
474
+ B_, H_, N_, D_ = q.shape
475
+ q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
476
+ k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
477
+
478
+ attn = ((q * self.scale) @ k.transpose(-2, -1))
479
+ # attn = attn - attn.max(-1)[0].unsqueeze(-1) # in case of overflow for fp16
480
+ attn = attn.softmax(dim=-1)
481
+ attn = self.attn_drop(attn)
482
+ # print(torch.cuda.memory_allocated(), torch.cuda.memory_allocated())
483
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
484
+ # print(f"\033[31m这{x.device}是{self.proj.weight.device} {self.proj.bias.device}\033[0m")
485
+ # print(f"\033[31m类型{x.dtype}是{self.proj.weight.dtype} {self.proj.bias.dtype}\033[0m")
486
+ x = self.proj(x)
487
+ x = self.proj_drop(x)
488
+ return x
489
+
490
+ def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
491
+
492
+ qkv = self.qkv(x)
493
+ qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
494
+
495
+ if self.qk_normalization:
496
+ q, k, v = qkv.unbind(2)
497
+ if self.use_fused_rmsnorm:
498
+ q = self.q_norm(q.flatten(-2, -1))[0].view(q.shape)
499
+ k = self.k_norm(k.flatten(-2, -1))[0].view(k.shape)
500
+ else:
501
+ q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
502
+ k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
503
+ qkv = torch.stack([q, k, v], dim=2)
504
+
505
+ context, _ = self.inner_attn(
506
+ qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
507
+ )
508
+ outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
509
+ outs = self.proj_drop(outs)
510
+ return outs
511
+
512
+ def forward(self, x):
513
+ x = self._naive_attn(x) if not self.use_flash_attn else self._flash_attn(x)
514
+ return x
515
+
516
+
517
+ class Mlp(nn.Module):
518
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
519
+ """
520
+
521
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
522
+ bias=True, drop=0.):
523
+ super().__init__()
524
+ out_features = out_features or in_features
525
+ hidden_features = hidden_features or in_features
526
+ bias = to_2tuple(bias)
527
+ drop_probs = to_2tuple(drop)
528
+
529
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
530
+ self.act = act_layer()
531
+ self.drop1 = nn.Dropout(drop_probs[0])
532
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
533
+ self.drop2 = nn.Dropout(drop_probs[1])
534
+
535
+ def forward(self, x):
536
+ x = self.fc1(x)
537
+ x = self.act(x)
538
+ x = self.drop1(x)
539
+ x = self.fc2(x)
540
+ x = self.drop2(x)
541
+ return x
542
+
543
+
544
+ class Block(nn.Module):
545
+
546
+ def __init__(
547
+ self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
548
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_flash_attn=False, use_fused_mlp=False,
549
+ fused_mlp_heuristic=1, with_cp=False, qk_normalization=False, layerscale_no_force_fp32=False,
550
+ use_fused_rmsnorm=False):
551
+ super().__init__()
552
+
553
+ self.norm1 = norm_layer(dim)
554
+ self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
555
+ use_flash_attn=use_flash_attn, causal=False, norm_layer=norm_layer,
556
+ qk_normalization=qk_normalization,
557
+ use_fused_rmsnorm=use_fused_rmsnorm)
558
+ self.ls1 = nn.Parameter(init_values * torch.ones(dim))
559
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
560
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
561
+
562
+ self.norm2 = norm_layer(dim)
563
+ mlp_hidden_dim = int(dim * mlp_ratio)
564
+ if use_fused_mlp:
565
+ self.mlp = FusedMLP(in_features=dim, hidden_features=mlp_hidden_dim, heuristic=fused_mlp_heuristic)
566
+ else:
567
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
568
+ self.ls2 = nn.Parameter(init_values * torch.ones(dim))
569
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
570
+
571
+ self.with_cp = with_cp
572
+ self.use_fused_rmsnorm = use_fused_rmsnorm
573
+
574
+ def forward(self, x, residual=None):
575
+
576
+ def _inner_forward(x, residual=None):
577
+ if self.use_fused_rmsnorm:
578
+ x, residual = self.norm1(x, residual)
579
+ x = self.drop_path1(self.ls1 * self.attn(x) )
580
+ x, residual = self.norm2(x, residual)
581
+ x = self.drop_path2(self.ls2 * self.mlp(x) )
582
+ return x, residual
583
+ else:
584
+ assert residual is None
585
+ x = x + self.drop_path1(self.ls1 * self.attn(self.norm1(x)))
586
+ x = x + self.drop_path2(self.ls2 * self.mlp(self.norm2(x)))
587
+ return x
588
+
589
+ if self.with_cp:
590
+ # print(f"\033[31m use_checkpoint [0m")
591
+ return checkpoint.checkpoint(_inner_forward, x, residual)
592
+ else:
593
+ return _inner_forward(x, residual=residual)
594
+
595
+
596
+ class PatchEmbed(nn.Module):
597
+ """ 3D Image to Patch Embedding
598
+ """
599
+
600
+ def __init__(
601
+ self, img_size=224, patch_size=16, in_chans=3, embed_dim=768,
602
+ num_frames=8, tubelet_size=1, norm_layer=None
603
+ ):
604
+ super().__init__()
605
+ img_size = to_2tuple(img_size)
606
+ patch_size = to_2tuple(patch_size)
607
+ self.tubelet_size = tubelet_size
608
+ self.img_size = img_size
609
+ self.patch_size = patch_size
610
+ self.grid_size = (
611
+ num_frames // tubelet_size,
612
+ img_size[0] // patch_size[0],
613
+ img_size[1] // patch_size[1]
614
+ ) # (T, H, W)
615
+ self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
616
+ self.num_img_patches = self.grid_size[1] * self.grid_size[2]
617
+
618
+ self.proj = nn.Conv3d(
619
+ in_channels=in_chans, out_channels=embed_dim,
620
+ kernel_size=(tubelet_size, patch_size[0], patch_size[1]),
621
+ stride=(tubelet_size, patch_size[0], patch_size[1])
622
+ )
623
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
624
+
625
+ def forward(self, x):
626
+ x = self.proj(x)
627
+ x = x.flatten(3).permute(0, 2, 3, 1) # B x C x T x HW => B x T x HW x C
628
+ x = self.norm(x)
629
+ return x
630
+
631
+ class PretrainVisionTransformer_clean(nn.Module):
632
+ def __init__(
633
+ self,
634
+ in_chans: int = 3,
635
+ patch_size: int = 14,
636
+ img_size: int = 224,
637
+ qkv_bias: bool = False, # follow internvl_clip to set False
638
+ drop_path_rate: float = 0.25, # may need ablation
639
+ embed_dim: int = 1408,
640
+ num_heads: int = 16,
641
+ mlp_ratio: float = 48/11,
642
+ init_values: float = 1e-5, # may need ablation
643
+ qk_normalization: bool = True,
644
+ depth: int = 40,
645
+ use_flash_attn: bool = True,
646
+ use_fused_rmsnorm: bool = True,
647
+ use_fused_mlp: bool = True,
648
+ fused_mlp_heuristic: int = 1,
649
+ attn_pool_num_heads: int = 16,
650
+ clip_embed_dim: int = 768,
651
+ layerscale_no_force_fp32: bool = False, # whether True for training?
652
+ num_frames: int = 8,
653
+ tubelet_size: int = 1,
654
+ sep_pos_embed: bool = False,
655
+ sep_image_video_pos_embed: bool = False,
656
+ use_checkpoint: bool = False,
657
+ checkpoint_num: int = 0,
658
+ # for unmasked teacher
659
+ x_vis_return_idx=-1,
660
+ x_vis_only=False
661
+ ):
662
+ super().__init__()
663
+
664
+ self.num_frames = num_frames
665
+ self.tubelet_size = tubelet_size
666
+ assert use_flash_attn == use_fused_rmsnorm == use_fused_mlp, 'use_flash_attn, use_fused_rmsnorm and use_fused_mlp should be consistent'
667
+
668
+ self.use_flash_attn = use_flash_attn
669
+ self.embed_dim = embed_dim
670
+
671
+ logger.info(f"Origin depth: {depth}")
672
+ depth = depth + x_vis_return_idx + 1
673
+ logger.info(f"New depth: {depth}")
674
+ self.depth = depth
675
+
676
+ self.x_vis_only = x_vis_only
677
+
678
+ if use_fused_rmsnorm:
679
+ norm_layer_for_blocks = partial(DropoutAddRMSNorm, eps=1e-6, prenorm=True)
680
+ else:
681
+ norm_layer_for_blocks = partial(RMSNorm, eps=1e-6)
682
+ self.norm_layer_for_blocks = norm_layer_for_blocks
683
+ self.patch_embed = PatchEmbed(
684
+ img_size, patch_size, in_chans, embed_dim,
685
+ num_frames=num_frames, tubelet_size=tubelet_size,
686
+ )
687
+ num_patches = self.patch_embed.num_patches
688
+ num_img_patches = self.patch_embed.num_img_patches
689
+ # print(f"num_patches: {num_patches}, num_img_patches: {num_img_patches}")
690
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
691
+
692
+ # stolen from https://github.com/facebookresearch/mae_st/blob/dc072aaaf640d06892e23a33b42223a994efe272/models_vit.py#L65-L73C17
693
+ self.sep_pos_embed = sep_pos_embed
694
+ self.sep_image_video_pos_embed = sep_image_video_pos_embed
695
+ if sep_pos_embed:
696
+ raise NotImplementedError
697
+ else:
698
+ if sep_image_video_pos_embed:
699
+ logger.info("Use joint position embedding, for image and video we use different pos_embed.")
700
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
701
+ self.img_pos_embed = nn.Parameter(torch.zeros(1, num_img_patches + 1, embed_dim))
702
+ else:
703
+ logger.info("Use joint position embedding, for image and video we use same pos_embed.")
704
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
705
+
706
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
707
+ # choose which layer to use checkpoint
708
+ with_cp_list = [False] * depth
709
+ if use_checkpoint:
710
+ for idx in range(depth):
711
+ if idx < checkpoint_num:
712
+ with_cp_list[idx] = True
713
+ logger.info(f"Droppath rate: {dpr}")
714
+ logger.info(f"Checkpoint list: {with_cp_list}")
715
+
716
+ self.blocks = nn.ModuleList([
717
+ Block(embed_dim, num_heads, mlp_ratio, qkv_bias=qkv_bias,
718
+ norm_layer=norm_layer_for_blocks,
719
+ drop_path=dpr[i], init_values=init_values, attn_drop=0.,
720
+ use_flash_attn=use_flash_attn, use_fused_mlp=use_fused_mlp,
721
+ fused_mlp_heuristic=fused_mlp_heuristic,
722
+ with_cp=with_cp_list[i],
723
+ qk_normalization=qk_normalization,
724
+ layerscale_no_force_fp32=layerscale_no_force_fp32,
725
+ use_fused_rmsnorm=use_fused_rmsnorm)
726
+ for i in range(depth)])
727
+
728
+ if not self.x_vis_only:
729
+ self.clip_projector = AttentionPoolingBlock(
730
+ dim=embed_dim, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
731
+ drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim)
732
+
733
+
734
+
735
+ self.init_pos_embed()
736
+ # trunc_normal_(self.cls_token, std=.02)
737
+ # self.apply(self._init_weights)
738
+ # self.fix_init_weight()
739
+
740
+ def init_pos_embed(self):
741
+ logger.info("Init pos_embed from sincos pos_embed")
742
+ if self.sep_pos_embed:
743
+ raise NotImplementedError
744
+ else:
745
+ pos_embed = get_3d_sincos_pos_embed(
746
+ self.pos_embed.shape[-1],
747
+ self.patch_embed.grid_size[1], # height & width
748
+ self.patch_embed.grid_size[0], # t_size
749
+ cls_token=True
750
+ )
751
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
752
+
753
+ if self.sep_image_video_pos_embed:
754
+ img_pos_embed = get_3d_sincos_pos_embed(
755
+ self.pos_embed.shape[-1],
756
+ self.patch_embed.grid_size[1], # height & width
757
+ 1,
758
+ cls_token=True
759
+ )
760
+ self.img_pos_embed.data.copy_(torch.from_numpy(img_pos_embed).float().unsqueeze(0))
761
+
762
+
763
+ def _init_weights(self, m):
764
+ if isinstance(m, nn.Linear):
765
+ trunc_normal_(m.weight, std=.02)
766
+ if isinstance(m, nn.Linear) and m.bias is not None:
767
+ nn.init.constant_(m.bias, 0)
768
+ elif isinstance(m, nn.LayerNorm):
769
+ nn.init.constant_(m.bias, 0)
770
+ nn.init.constant_(m.weight, 1.0)
771
+
772
+ def fix_init_weight(self):
773
+ def rescale(param, layer_id):
774
+ param.div_(math.sqrt(2.0 * layer_id))
775
+
776
+ for layer_id, layer in enumerate(self.blocks):
777
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
778
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
779
+
780
+ @property
781
+ def dtype(self):
782
+ return self.patch_embed.proj.weight.dtype
783
+
784
+ def get_num_layers(self):
785
+ return len(self.blocks)
786
+
787
+ @torch.jit.ignore
788
+ def no_weight_decay(self):
789
+ return {
790
+ 'pos_embed',
791
+ 'pos_embed_spatial',
792
+ 'pos_embed_temporal',
793
+ 'pos_embed_cls',
794
+ 'img_pos_embed',
795
+ 'cls_token'
796
+ }
797
+
798
+ def expand_pos_embed(self, pos_embed, new_t_size, L, use_vitar_fuzzing=False):
799
+ '''
800
+ @param:
801
+ pos_embed: original pos_embed, (1, T*L + 1, embed_dim)
802
+ T: frames
803
+ L: w * h
804
+ method: interpolation method
805
+ '''
806
+ pos_embed_checkpoint = pos_embed
807
+ embedding_size = pos_embed_checkpoint.shape[-1]
808
+ num_extra_tokens = 1
809
+
810
+ # height (== width) for the checkpoint position embedding
811
+ orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens)//(self.num_frames / self.patch_embed.tubelet_size)) ** 0.5)
812
+ # height (== width) for the new position embedding
813
+ new_size = int(L ** 0.5)
814
+
815
+ # class_token and dist_token are kept unchanged
816
+ if self.num_frames != new_t_size:
817
+ logger.info(f"Temporal interpolate from {self.num_frames} to {new_t_size} ")
818
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
819
+ # only the position tokens are interpolated
820
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
821
+ # B, L, C -> B, T, HW, C -> BHW, C, T (B = 1)
822
+ pos_tokens = pos_tokens.view(1, self.num_frames, -1, embedding_size)
823
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size, self.num_frames)
824
+ pos_tokens = torch.nn.functional.interpolate(pos_tokens.cpu(), size=new_t_size, mode='linear').cuda()
825
+ pos_tokens = pos_tokens.view(1, -1, embedding_size, new_t_size)
826
+ pos_tokens = pos_tokens.permute(0, 3, 1, 2).reshape(1, -1, embedding_size)
827
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
828
+ pos_embed_checkpoint = new_pos_embed
829
+
830
+ # class_token and dist_token are kept unchanged
831
+ if orig_size != new_size:
832
+ logger.info(f"Position interpolate from {orig_size}x{orig_size} to {new_size}x{new_size}")
833
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
834
+ # only the position tokens are interpolated
835
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
836
+ # B, L, C -> BT, H, W, C -> BT, C, H, W
837
+ pos_tokens = pos_tokens.reshape(-1, new_t_size, orig_size, orig_size, embedding_size)
838
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
839
+ pos_tokens = torch.nn.functional.interpolate(
840
+ pos_tokens.cpu(), size=(new_size, new_size), mode='bicubic', align_corners=False).cuda()
841
+ # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
842
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, new_t_size, new_size, new_size, embedding_size)
843
+ pos_tokens = pos_tokens.flatten(1, 3) # B, L, C
844
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
845
+
846
+ if use_vitar_fuzzing:
847
+ ...
848
+
849
+ return new_pos_embed
850
+
851
+ # @torch.cuda.amp.autocast(enabled=False)
852
+ def forward(self, x, mask=None, use_image=False):
853
+ x = self.patch_embed(x.type(self.dtype))
854
+ # print(f"x.shape: {x.shape} x.dtype: {x.dtype}, model.dtype: {self.dtype}")
855
+ B, T, L, C = x.shape # T: temporal; L: spatial
856
+ x = x.view([B, T * L, C])
857
+
858
+ # append cls token
859
+ cls_tokens = self.cls_token.expand(B, -1, -1)
860
+ x = torch.cat((cls_tokens, x), dim=1)
861
+
862
+ # add pos_embed
863
+ if self.sep_pos_embed:
864
+ raise NotImplementedError
865
+ else:
866
+ if use_image:
867
+ if self.sep_image_video_pos_embed:
868
+ pos_embed = self.img_pos_embed
869
+ else:
870
+ # (1, num_img_patches + 1, embed_dim)
871
+ # print('origin pos_embed.shape:', self.pos_embed.shape)
872
+ cls_pos_embed = self.pos_embed[:, 0:1, :]
873
+ # print('cls_pos_embed.shape:', cls_pos_embed.shape)
874
+
875
+ img_pos_embed = self.pos_embed[:, 1:, :].view(1, self.num_frames, self.patch_embed.num_patches // self.num_frames, self.embed_dim).mean(dim=1)
876
+ # print('img_pos_embed.shape:', img_pos_embed.shape)
877
+
878
+ pos_embed = torch.cat([cls_pos_embed, img_pos_embed], dim=1)
879
+ # print('final img_pos_embed.shape:', pos_embed.shape)
880
+ else:
881
+ pos_embed = self.pos_embed
882
+
883
+ if pos_embed[0].shape != x[0].shape:
884
+ # print(f'pos embed shape {pos_embed.shape} does not match x[0].shape {x[0].shape}')
885
+ pos_embed = self.expand_pos_embed(pos_embed, T, L) # can accelerate here
886
+ assert pos_embed[0].shape == x[0].shape, f'pos embed shape: {pos_embed.shape} not match x[0].shape {x[0].shape}'
887
+ # print("pos_embed.shape:", pos_embed.shape)
888
+ x = x + pos_embed
889
+
890
+ # mask tokens, ~mask means visible
891
+ if mask is not None:
892
+ x = x[~mask].reshape(B, -1, C)
893
+ else:
894
+ x = x.reshape(B, -1, C)
895
+
896
+ residual = None
897
+
898
+ for idx, blk in enumerate(self.blocks):
899
+ if isinstance(x, tuple) and len(x) == 2:
900
+ x, residual = x
901
+ x = blk(x, residual=residual)
902
+
903
+ if isinstance(x, tuple) and len(x) == 2:
904
+ x, residual = x
905
+ if residual is not None:
906
+ x = x + residual
907
+
908
+ x_vis = x
909
+ if self.x_vis_only:
910
+ return x_vis
911
+ else:
912
+ x_pool_vis = self.clip_projector(x_vis)
913
+ return x_vis, x_pool_vis, None, None
914
+
915
+
916
+ def pretrain_internvideo2_giant_patch14_224_clean(config):
917
+ model = PretrainVisionTransformer_clean(
918
+ in_chans=3, img_size=224, patch_size=14,
919
+ embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11,
920
+ attn_pool_num_heads=16, qkv_bias=False,
921
+ drop_path_rate=0.25,
922
+ init_values=0.00001,
923
+ qk_normalization=True,
924
+ use_flash_attn=config.vision_encoder.get('use_flash_attn', False),
925
+ use_fused_rmsnorm=config.vision_encoder.get('use_fused_rmsnorm', False),
926
+ use_fused_mlp=config.vision_encoder.get('use_fused_mlp', False),
927
+ fused_mlp_heuristic=1,
928
+ layerscale_no_force_fp32=True,
929
+ num_frames=config.vision_encoder.num_frames,
930
+ tubelet_size=config.vision_encoder.tubelet_size,
931
+ sep_pos_embed=False,
932
+ sep_image_video_pos_embed=config.vision_encoder.sep_image_video_pos_embed,
933
+ use_checkpoint=config.vision_encoder.use_checkpoint,
934
+ checkpoint_num=config.vision_encoder.checkpoint_num,
935
+ x_vis_return_idx=config.vision_encoder.x_vis_return_idx,
936
+ x_vis_only=config.vision_encoder.x_vis_only,
937
+ )
938
+
939
+ if config.vision_encoder.pretrained is not None:
940
+ logger.info(f"Loading pretrained weights from {config.vision_encoder.pretrained}")
941
+ state_dict = torch.load(config.vision_encoder.pretrained, map_location='cpu')
942
+ interpolate_pos_embed_internvideo2(state_dict, model, orig_t_size=4) # NOTE 8f for stage1
943
+ message = model.load_state_dict(state_dict, strict=False)
944
+ logger.info(message)
945
+ else:
946
+ logger.info("No pretrained weights!!!")
947
+ return model
948
+
949
+
950
+
951
+ def pretrain_internvideo2_6b_patch14_224_clean(config):
952
+ model = PretrainVisionTransformer_clean(
953
+ in_chans=3, img_size=224, patch_size=14,
954
+ embed_dim=3200, depth=48, num_heads=25, mlp_ratio=4,
955
+ clip_embed_dim=config.vision_encoder.clip_embed_dim,
956
+ attn_pool_num_heads=16, qkv_bias=False,
957
+ drop_path_rate=0.3,
958
+ init_values=0.00001,
959
+ qk_normalization=True,
960
+ use_flash_attn=config.vision_encoder.get('use_flash_attn', True),
961
+ use_fused_rmsnorm=config.vision_encoder.get('use_fused_rmsnorm', True),
962
+ use_fused_mlp=config.vision_encoder.get('use_fused_mlp', True),
963
+ fused_mlp_heuristic=1,
964
+ layerscale_no_force_fp32=True,
965
+ num_frames=config.vision_encoder.num_frames,
966
+ tubelet_size=config.vision_encoder.tubelet_size,
967
+ sep_pos_embed=False,
968
+ sep_image_video_pos_embed=config.vision_encoder.sep_image_video_pos_embed,
969
+ use_checkpoint=config.vision_encoder.use_checkpoint,
970
+ checkpoint_num=config.vision_encoder.checkpoint_num,
971
+ x_vis_return_idx=config.vision_encoder.x_vis_return_idx,
972
+ x_vis_only=config.vision_encoder.x_vis_only
973
+ )
974
+
975
+ if config.vision_encoder.pretrained is not None:
976
+ logger.info(f"Loading pretrained weights from {config.vision_encoder.pretrained}")
977
+ state_dict = torch.load(config.vision_encoder.pretrained, map_location='cpu')
978
+ interpolate_pos_embed_internvideo2(state_dict, model, orig_t_size=8) # NOTE 8f for stage1
979
+ msg = model.load_state_dict(state_dict, strict=False)
980
+ logger.info(msg)
981
+ else:
982
+ logger.info("No pretrained weights!!!")
983
+ return model
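
Note: both factory functions above read their hyper-parameters from an attribute-style config (`config.vision_encoder.*`). A minimal usage sketch for the 1B ("giant") encoder, assuming a small dict wrapper and illustrative field values; nothing below is shipped with this commit:

# Hypothetical sketch: build the giant encoder from a minimal attribute-style config.
# _AttrDict, every value below, and the absence of a checkpoint are assumptions.
class _AttrDict(dict):
    __getattr__ = dict.__getitem__   # supports both cfg.get(...) and cfg.num_frames

cfg = _AttrDict(vision_encoder=_AttrDict(
    num_frames=16, tubelet_size=1, sep_image_video_pos_embed=True,
    use_checkpoint=True, checkpoint_num=48,
    x_vis_return_idx=-2, x_vis_only=True,
    use_flash_attn=False, use_fused_rmsnorm=False, use_fused_mlp=False,  # keep fused kernels off unless flash-attn is installed
    pretrained=None,   # or a filesystem path to an InternVideo2 checkpoint
))
vision_encoder = pretrain_internvideo2_giant_patch14_224_clean(cfg)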
modeling_qformer.py ADDED
@@ -0,0 +1,1270 @@
1
+ """
2
+ * Copyright (c) 2023, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ """
10
+ import logging
11
+ import math
12
+ import os
13
+ import warnings
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple, Dict, Any
16
+
17
+ import torch
18
+ from torch import Tensor, device, dtype, nn
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+ from torch.nn import CrossEntropyLoss
22
+ import torch.nn.functional as F
23
+
24
+ from timm.models.layers import drop_path
25
+ from transformers.activations import ACT2FN
26
+ from transformers.file_utils import (
27
+ ModelOutput,
28
+ )
29
+ from transformers.modeling_outputs import (
30
+ BaseModelOutputWithPastAndCrossAttentions,
31
+ BaseModelOutputWithPoolingAndCrossAttentions,
32
+ CausalLMOutputWithCrossAttentions,
33
+ MaskedLMOutput,
34
+ MultipleChoiceModelOutput,
35
+ NextSentencePredictorOutput,
36
+ QuestionAnsweringModelOutput,
37
+ SequenceClassifierOutput,
38
+ TokenClassifierOutput,
39
+ )
40
+ from transformers.modeling_utils import (
41
+ PreTrainedModel,
42
+ apply_chunking_to_forward,
43
+ find_pruneable_heads_and_indices,
44
+ prune_linear_layer,
45
+ )
46
+ from transformers.models.bert.configuration_bert import BertConfig
47
+
48
+ import logging
49
+ logger = logging.getLogger(__name__)
50
+
51
+
52
+ class BertEmbeddings(nn.Module):
53
+ """Construct the embeddings from word and position embeddings."""
54
+
55
+ def __init__(self, config):
56
+ super().__init__()
57
+ self.word_embeddings = nn.Embedding(
58
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
59
+ )
60
+ self.position_embeddings = nn.Embedding(
61
+ config.max_position_embeddings, config.hidden_size
62
+ )
63
+
64
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
65
+ # any TensorFlow checkpoint file
66
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
67
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
68
+
69
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
70
+ self.register_buffer(
71
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
72
+ )
73
+ self.position_embedding_type = getattr(
74
+ config, "position_embedding_type", "absolute"
75
+ )
76
+
77
+ self.config = config
78
+
79
+ def forward(
80
+ self,
81
+ input_ids=None,
82
+ position_ids=None,
83
+ query_embeds=None,
84
+ past_key_values_length=0,
85
+ ):
86
+ if input_ids is not None:
87
+ seq_length = input_ids.size()[1]
88
+ else:
89
+ seq_length = 0
90
+
91
+ if position_ids is None:
92
+ position_ids = self.position_ids[
93
+ :, past_key_values_length : seq_length + past_key_values_length
94
+ ].clone()
95
+
96
+ if input_ids is not None:
97
+ embeddings = self.word_embeddings(input_ids)
98
+ if self.position_embedding_type == "absolute":
99
+ position_embeddings = self.position_embeddings(position_ids)
100
+ embeddings = embeddings + position_embeddings
101
+
102
+ if query_embeds is not None:
103
+ embeddings = torch.cat((query_embeds, embeddings), dim=1)
104
+ else:
105
+ embeddings = query_embeds
106
+
107
+ embeddings = self.LayerNorm(embeddings)
108
+ embeddings = self.dropout(embeddings)
109
+ return embeddings
110
+
111
+
112
+ class BertSelfAttention(nn.Module):
113
+ def __init__(self, config, is_cross_attention):
114
+ super().__init__()
115
+ self.config = config
116
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
117
+ config, "embedding_size"
118
+ ):
119
+ raise ValueError(
120
+ "The hidden size (%d) is not a multiple of the number of attention "
121
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
122
+ )
123
+
124
+ self.num_attention_heads = config.num_attention_heads
125
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
126
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
127
+
128
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
129
+ if is_cross_attention:
130
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
131
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
132
+ else:
133
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
134
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
135
+
136
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
137
+ self.position_embedding_type = getattr(
138
+ config, "position_embedding_type", "absolute"
139
+ )
140
+ if (
141
+ self.position_embedding_type == "relative_key"
142
+ or self.position_embedding_type == "relative_key_query"
143
+ ):
144
+ self.max_position_embeddings = config.max_position_embeddings
145
+ self.distance_embedding = nn.Embedding(
146
+ 2 * config.max_position_embeddings - 1, self.attention_head_size
147
+ )
148
+ self.save_attention = False
149
+
150
+ def save_attn_gradients(self, attn_gradients):
151
+ self.attn_gradients = attn_gradients
152
+
153
+ def get_attn_gradients(self):
154
+ return self.attn_gradients
155
+
156
+ def save_attention_map(self, attention_map):
157
+ self.attention_map = attention_map
158
+
159
+ def get_attention_map(self):
160
+ return self.attention_map
161
+
162
+ def transpose_for_scores(self, x):
163
+ new_x_shape = x.size()[:-1] + (
164
+ self.num_attention_heads,
165
+ self.attention_head_size,
166
+ )
167
+ x = x.view(*new_x_shape)
168
+ return x.permute(0, 2, 1, 3)
169
+
170
+ def forward(
171
+ self,
172
+ hidden_states,
173
+ attention_mask=None,
174
+ head_mask=None,
175
+ encoder_hidden_states=None,
176
+ encoder_attention_mask=None,
177
+ past_key_value=None,
178
+ output_attentions=False,
179
+ ):
180
+
181
+ # If this is instantiated as a cross-attention module, the keys
182
+ # and values come from an encoder; the attention mask needs to be
183
+ # such that the encoder's padding tokens are not attended to.
184
+ is_cross_attention = encoder_hidden_states is not None
185
+
186
+ if is_cross_attention:
187
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
188
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
189
+ attention_mask = encoder_attention_mask
190
+ elif past_key_value is not None:
191
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
192
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
193
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
194
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
195
+ else:
196
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
197
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
198
+
199
+ mixed_query_layer = self.query(hidden_states)
200
+
201
+ query_layer = self.transpose_for_scores(mixed_query_layer)
202
+
203
+ past_key_value = (key_layer, value_layer)
204
+
205
+ # Take the dot product between "query" and "key" to get the raw attention scores.
206
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
207
+
208
+ if (
209
+ self.position_embedding_type == "relative_key"
210
+ or self.position_embedding_type == "relative_key_query"
211
+ ):
212
+ seq_length = hidden_states.size()[1]
213
+ position_ids_l = torch.arange(
214
+ seq_length, dtype=torch.long, device=hidden_states.device
215
+ ).view(-1, 1)
216
+ position_ids_r = torch.arange(
217
+ seq_length, dtype=torch.long, device=hidden_states.device
218
+ ).view(1, -1)
219
+ distance = position_ids_l - position_ids_r
220
+ positional_embedding = self.distance_embedding(
221
+ distance + self.max_position_embeddings - 1
222
+ )
223
+ positional_embedding = positional_embedding.to(
224
+ dtype=query_layer.dtype
225
+ ) # fp16 compatibility
226
+
227
+ if self.position_embedding_type == "relative_key":
228
+ relative_position_scores = torch.einsum(
229
+ "bhld,lrd->bhlr", query_layer, positional_embedding
230
+ )
231
+ attention_scores = attention_scores + relative_position_scores
232
+ elif self.position_embedding_type == "relative_key_query":
233
+ relative_position_scores_query = torch.einsum(
234
+ "bhld,lrd->bhlr", query_layer, positional_embedding
235
+ )
236
+ relative_position_scores_key = torch.einsum(
237
+ "bhrd,lrd->bhlr", key_layer, positional_embedding
238
+ )
239
+ attention_scores = (
240
+ attention_scores
241
+ + relative_position_scores_query
242
+ + relative_position_scores_key
243
+ )
244
+
245
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
246
+ if attention_mask is not None:
247
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
248
+ attention_scores = attention_scores + attention_mask
249
+
250
+ # Normalize the attention scores to probabilities.
251
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
252
+
253
+ if is_cross_attention and self.save_attention:
254
+ self.save_attention_map(attention_probs)
255
+ attention_probs.register_hook(self.save_attn_gradients)
256
+
257
+ # This is actually dropping out entire tokens to attend to, which might
258
+ # seem a bit unusual, but is taken from the original Transformer paper.
259
+ attention_probs_dropped = self.dropout(attention_probs)
260
+
261
+ # Mask heads if we want to
262
+ if head_mask is not None:
263
+ attention_probs_dropped = attention_probs_dropped * head_mask
264
+
265
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
266
+
267
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
268
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
269
+ context_layer = context_layer.view(*new_context_layer_shape)
270
+
271
+ outputs = (
272
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
273
+ )
274
+
275
+ outputs = outputs + (past_key_value,)
276
+ return outputs
277
+
278
+
279
+ class DropPath(nn.Module):
280
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
281
+ """
282
+ def __init__(self, drop_prob=None):
283
+ super(DropPath, self).__init__()
284
+ self.drop_prob = drop_prob
285
+
286
+ def forward(self, x):
287
+ return drop_path(x, self.drop_prob, self.training)
288
+
289
+ def extra_repr(self) -> str:
290
+ return 'p={}'.format(self.drop_prob)
291
+
292
+
293
+ class BertSelfOutput(nn.Module):
294
+ def __init__(self, config, drop_path=0.):
295
+ super().__init__()
296
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
297
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
298
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
299
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
300
+
301
+ def forward(self, hidden_states, input_tensor):
302
+ hidden_states = self.dense(hidden_states)
303
+ hidden_states = self.dropout(hidden_states)
304
+ hidden_states = self.drop_path(hidden_states)
305
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
306
+ return hidden_states
307
+
308
+
309
+ class BertAttention(nn.Module):
310
+ def __init__(self, config, is_cross_attention=False, drop_path=0.,):
311
+ super().__init__()
312
+ self.self = BertSelfAttention(config, is_cross_attention)
313
+ self.output = BertSelfOutput(config, drop_path=drop_path)
314
+ self.pruned_heads = set()
315
+
316
+ def prune_heads(self, heads):
317
+ if len(heads) == 0:
318
+ return
319
+ heads, index = find_pruneable_heads_and_indices(
320
+ heads,
321
+ self.self.num_attention_heads,
322
+ self.self.attention_head_size,
323
+ self.pruned_heads,
324
+ )
325
+
326
+ # Prune linear layers
327
+ self.self.query = prune_linear_layer(self.self.query, index)
328
+ self.self.key = prune_linear_layer(self.self.key, index)
329
+ self.self.value = prune_linear_layer(self.self.value, index)
330
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
331
+
332
+ # Update hyper params and store pruned heads
333
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
334
+ self.self.all_head_size = (
335
+ self.self.attention_head_size * self.self.num_attention_heads
336
+ )
337
+ self.pruned_heads = self.pruned_heads.union(heads)
338
+
339
+ def forward(
340
+ self,
341
+ hidden_states,
342
+ attention_mask=None,
343
+ head_mask=None,
344
+ encoder_hidden_states=None,
345
+ encoder_attention_mask=None,
346
+ past_key_value=None,
347
+ output_attentions=False,
348
+ ):
349
+ self_outputs = self.self(
350
+ hidden_states,
351
+ attention_mask,
352
+ head_mask,
353
+ encoder_hidden_states,
354
+ encoder_attention_mask,
355
+ past_key_value,
356
+ output_attentions,
357
+ )
358
+ attention_output = self.output(self_outputs[0], hidden_states)
359
+
360
+ outputs = (attention_output,) + self_outputs[
361
+ 1:
362
+ ] # add attentions if we output them
363
+ return outputs
364
+
365
+
366
+ class BertIntermediate(nn.Module):
367
+ def __init__(self, config):
368
+ super().__init__()
369
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
370
+ if isinstance(config.hidden_act, str):
371
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
372
+ else:
373
+ self.intermediate_act_fn = config.hidden_act
374
+
375
+ def forward(self, hidden_states):
376
+ hidden_states = self.dense(hidden_states)
377
+ hidden_states = self.intermediate_act_fn(hidden_states)
378
+ return hidden_states
379
+
380
+
381
+ class BertOutput(nn.Module):
382
+ def __init__(self, config, drop_path=0.):
383
+ super().__init__()
384
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
385
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
386
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
387
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
388
+
389
+ def forward(self, hidden_states, input_tensor):
390
+ hidden_states = self.dense(hidden_states)
391
+ hidden_states = self.dropout(hidden_states)
392
+ hidden_states = self.drop_path(hidden_states)
393
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
394
+ return hidden_states
395
+
396
+
397
+ class BertLayer(nn.Module):
398
+ def __init__(self, config, layer_num):
399
+ super().__init__()
400
+ self.config = config
401
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
402
+ self.seq_len_dim = 1
403
+ drop_path = config.drop_path_list[layer_num]
404
+ self.attention = BertAttention(config, drop_path=drop_path)
405
+ self.layer_num = layer_num
406
+ if (
407
+ self.config.add_cross_attention
408
+ and layer_num % self.config.cross_attention_freq == 0
409
+ ):
410
+ self.crossattention = BertAttention(
411
+ config, is_cross_attention=self.config.add_cross_attention,
412
+ drop_path=drop_path
413
+ )
414
+ self.has_cross_attention = True
415
+ else:
416
+ self.has_cross_attention = False
417
+ self.intermediate = BertIntermediate(config)
418
+ self.output = BertOutput(config, drop_path=drop_path)
419
+
420
+ self.intermediate_query = BertIntermediate(config)
421
+ self.output_query = BertOutput(config, drop_path=drop_path)
422
+
423
+ def forward(
424
+ self,
425
+ hidden_states,
426
+ attention_mask=None,
427
+ head_mask=None,
428
+ encoder_hidden_states=None,
429
+ encoder_attention_mask=None,
430
+ past_key_value=None,
431
+ output_attentions=False,
432
+ query_length=0,
433
+ ):
434
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
435
+ self_attn_past_key_value = (
436
+ past_key_value[:2] if past_key_value is not None else None
437
+ )
438
+ self_attention_outputs = self.attention(
439
+ hidden_states,
440
+ attention_mask,
441
+ head_mask,
442
+ output_attentions=output_attentions,
443
+ past_key_value=self_attn_past_key_value,
444
+ )
445
+ attention_output = self_attention_outputs[0]
446
+ outputs = self_attention_outputs[1:-1]
447
+
448
+ present_key_value = self_attention_outputs[-1]
449
+
450
+ if query_length > 0:
451
+ query_attention_output = attention_output[:, :query_length, :]
452
+
453
+ if self.has_cross_attention:
454
+ assert (
455
+ encoder_hidden_states is not None
456
+ ), "encoder_hidden_states must be given for cross-attention layers"
457
+ cross_attention_outputs = self.crossattention(
458
+ query_attention_output,
459
+ attention_mask,
460
+ head_mask,
461
+ encoder_hidden_states,
462
+ encoder_attention_mask,
463
+ output_attentions=output_attentions,
464
+ )
465
+ query_attention_output = cross_attention_outputs[0]
466
+ outputs = (
467
+ outputs + cross_attention_outputs[1:-1]
468
+ ) # add cross attentions if we output attention weights
469
+
470
+ layer_output = apply_chunking_to_forward(
471
+ self.feed_forward_chunk_query,
472
+ self.chunk_size_feed_forward,
473
+ self.seq_len_dim,
474
+ query_attention_output,
475
+ )
476
+ if attention_output.shape[1] > query_length:
477
+ layer_output_text = apply_chunking_to_forward(
478
+ self.feed_forward_chunk,
479
+ self.chunk_size_feed_forward,
480
+ self.seq_len_dim,
481
+ attention_output[:, query_length:, :],
482
+ )
483
+ layer_output = torch.cat([layer_output, layer_output_text], dim=1)
484
+ else:
485
+ layer_output = apply_chunking_to_forward(
486
+ self.feed_forward_chunk,
487
+ self.chunk_size_feed_forward,
488
+ self.seq_len_dim,
489
+ attention_output,
490
+ )
491
+ outputs = (layer_output,) + outputs
492
+
493
+ outputs = outputs + (present_key_value,)
494
+
495
+ return outputs
496
+
497
+ def feed_forward_chunk(self, attention_output):
498
+ intermediate_output = self.intermediate(attention_output)
499
+ layer_output = self.output(intermediate_output, attention_output)
500
+ return layer_output
501
+
502
+ def feed_forward_chunk_query(self, attention_output):
503
+ intermediate_output = self.intermediate_query(attention_output)
504
+ layer_output = self.output_query(intermediate_output, attention_output)
505
+ return layer_output
506
+
507
+
508
+ class BertEncoder(nn.Module):
509
+ def __init__(self, config):
510
+ super().__init__()
511
+ self.config = config
512
+ self.layer = nn.ModuleList(
513
+ [BertLayer(config, i) for i in range(config.num_hidden_layers)]
514
+ )
515
+
516
+ def forward(
517
+ self,
518
+ hidden_states,
519
+ attention_mask=None,
520
+ head_mask=None,
521
+ encoder_hidden_states=None,
522
+ encoder_attention_mask=None,
523
+ past_key_values=None,
524
+ use_cache=None,
525
+ output_attentions=False,
526
+ output_hidden_states=False,
527
+ return_dict=True,
528
+ query_length=0,
529
+ ):
530
+ all_hidden_states = () if output_hidden_states else None
531
+ all_self_attentions = () if output_attentions else None
532
+ all_cross_attentions = (
533
+ () if output_attentions and self.config.add_cross_attention else None
534
+ )
535
+
536
+ next_decoder_cache = () if use_cache else None
537
+
538
+ for i in range(self.config.num_hidden_layers):
539
+ layer_module = self.layer[i]
540
+ if output_hidden_states:
541
+ all_hidden_states = all_hidden_states + (hidden_states,)
542
+
543
+ layer_head_mask = head_mask[i] if head_mask is not None else None
544
+ past_key_value = past_key_values[i] if past_key_values is not None else None
545
+
546
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
547
+
548
+ if use_cache:
549
+ logger.warn(
550
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
551
+ )
552
+ use_cache = False
553
+
554
+ def create_custom_forward(module):
555
+ def custom_forward(*inputs):
556
+ return module(
557
+ *inputs, past_key_value, output_attentions, query_length
558
+ )
559
+
560
+ return custom_forward
561
+
562
+ layer_outputs = torch.utils.checkpoint.checkpoint(
563
+ create_custom_forward(layer_module),
564
+ hidden_states,
565
+ attention_mask,
566
+ layer_head_mask,
567
+ encoder_hidden_states,
568
+ encoder_attention_mask,
569
+ )
570
+ else:
571
+ layer_outputs = layer_module(
572
+ hidden_states,
573
+ attention_mask,
574
+ layer_head_mask,
575
+ encoder_hidden_states,
576
+ encoder_attention_mask,
577
+ past_key_value,
578
+ output_attentions,
579
+ query_length,
580
+ )
581
+
582
+ hidden_states = layer_outputs[0]
583
+ if use_cache:
584
+ next_decoder_cache += (layer_outputs[-1],)
585
+ if output_attentions:
586
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
587
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
588
+
589
+ if output_hidden_states:
590
+ all_hidden_states = all_hidden_states + (hidden_states,)
591
+
592
+ if not return_dict:
593
+ return tuple(
594
+ v
595
+ for v in [
596
+ hidden_states,
597
+ next_decoder_cache,
598
+ all_hidden_states,
599
+ all_self_attentions,
600
+ all_cross_attentions,
601
+ ]
602
+ if v is not None
603
+ )
604
+ return BaseModelOutputWithPastAndCrossAttentions(
605
+ last_hidden_state=hidden_states,
606
+ past_key_values=next_decoder_cache,
607
+ hidden_states=all_hidden_states,
608
+ attentions=all_self_attentions,
609
+ cross_attentions=all_cross_attentions,
610
+ )
611
+
612
+
613
+ class BertPooler(nn.Module):
614
+ def __init__(self, config):
615
+ super().__init__()
616
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
617
+ self.activation = nn.Tanh()
618
+
619
+ def forward(self, hidden_states):
620
+ # We "pool" the model by simply taking the hidden state corresponding
621
+ # to the first token.
622
+ first_token_tensor = hidden_states[:, 0]
623
+ pooled_output = self.dense(first_token_tensor)
624
+ pooled_output = self.activation(pooled_output)
625
+ return pooled_output
626
+
627
+
628
+ class BertPredictionHeadTransform(nn.Module):
629
+ def __init__(self, config):
630
+ super().__init__()
631
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
632
+ if isinstance(config.hidden_act, str):
633
+ self.transform_act_fn = ACT2FN[config.hidden_act]
634
+ else:
635
+ self.transform_act_fn = config.hidden_act
636
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
637
+
638
+ def forward(self, hidden_states):
639
+ hidden_states = self.dense(hidden_states)
640
+ hidden_states = self.transform_act_fn(hidden_states)
641
+ hidden_states = self.LayerNorm(hidden_states)
642
+ return hidden_states
643
+
644
+
645
+ class BertLMPredictionHead(nn.Module):
646
+ def __init__(self, config):
647
+ super().__init__()
648
+ self.transform = BertPredictionHeadTransform(config)
649
+
650
+ # The output weights are the same as the input embeddings, but there is
651
+ # an output-only bias for each token.
652
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
653
+
654
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
655
+
656
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
657
+ self.decoder.bias = self.bias
658
+
659
+ def forward(self, hidden_states):
660
+ hidden_states = self.transform(hidden_states)
661
+ hidden_states = self.decoder(hidden_states)
662
+ return hidden_states
663
+
664
+
665
+ class BertOnlyMLMHead(nn.Module):
666
+ def __init__(self, config):
667
+ super().__init__()
668
+ self.predictions = BertLMPredictionHead(config)
669
+
670
+ def forward(self, sequence_output):
671
+ prediction_scores = self.predictions(sequence_output)
672
+ return prediction_scores
673
+
674
+
675
+ class BertPreTrainedModel(PreTrainedModel):
676
+ """
677
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
678
+ models.
679
+ """
680
+
681
+ config_class = BertConfig
682
+ base_model_prefix = "bert"
683
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
684
+
685
+ def _init_weights(self, module):
686
+ """Initialize the weights"""
687
+ if isinstance(module, (nn.Linear, nn.Embedding)):
688
+ # Slightly different from the TF version which uses truncated_normal for initialization
689
+ # cf https://github.com/pytorch/pytorch/pull/5617
690
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
691
+ elif isinstance(module, nn.LayerNorm):
692
+ module.bias.data.zero_()
693
+ module.weight.data.fill_(1.0)
694
+ if isinstance(module, nn.Linear) and module.bias is not None:
695
+ module.bias.data.zero_()
696
+
697
+
698
+ class BertModel(BertPreTrainedModel):
699
+ """
700
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
701
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
702
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
703
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
704
+ To behave as a decoder the model needs to be initialized with the :obj:`is_decoder` argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
705
+ input to the forward pass.
706
+ """
707
+
708
+ def __init__(self, config, add_pooling_layer=False):
709
+ super().__init__(config)
710
+ self.config = config
711
+
712
+ self.embeddings = BertEmbeddings(config)
713
+
714
+ self.encoder = BertEncoder(config)
715
+
716
+ self.pooler = BertPooler(config) if add_pooling_layer else None
717
+
718
+ self.init_weights()
719
+
720
+ def get_input_embeddings(self):
721
+ return self.embeddings.word_embeddings
722
+
723
+ def set_input_embeddings(self, value):
724
+ self.embeddings.word_embeddings = value
725
+
726
+ def _prune_heads(self, heads_to_prune):
727
+ """
728
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
729
+ class PreTrainedModel
730
+ """
731
+ for layer, heads in heads_to_prune.items():
732
+ self.encoder.layer[layer].attention.prune_heads(heads)
733
+
734
+ def get_extended_attention_mask(
735
+ self,
736
+ attention_mask: Tensor,
737
+ input_shape: Tuple[int],
738
+ device: device,
739
+ is_decoder: bool,
740
+ has_query: bool = False,
741
+ ) -> Tensor:
742
+ """
743
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
744
+
745
+ Arguments:
746
+ attention_mask (:obj:`torch.Tensor`):
747
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
748
+ input_shape (:obj:`Tuple[int]`):
749
+ The shape of the input to the model.
750
+ device: (:obj:`torch.device`):
751
+ The device of the input to the model.
752
+
753
+ Returns:
754
+ :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
755
+ """
756
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
757
+ # ourselves in which case we just need to make it broadcastable to all heads.
758
+ if attention_mask.dim() == 3:
759
+ extended_attention_mask = attention_mask[:, None, :, :]
760
+ elif attention_mask.dim() == 2:
761
+ # Provided a padding mask of dimensions [batch_size, seq_length]
762
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
763
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
764
+ if is_decoder:
765
+ batch_size, seq_length = input_shape
766
+
767
+ seq_ids = torch.arange(seq_length, device=device)
768
+ causal_mask = (
769
+ seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
770
+ <= seq_ids[None, :, None]
771
+ )
772
+
773
+ # add a prefix ones mask to the causal mask
774
+ # causal and attention masks must have same type with pytorch version < 1.3
775
+ causal_mask = causal_mask.to(attention_mask.dtype)
776
+
777
+ if causal_mask.shape[1] < attention_mask.shape[1]:
778
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
779
+ if has_query: # UniLM style attention mask
780
+ causal_mask = torch.cat(
781
+ [
782
+ torch.zeros(
783
+ (batch_size, prefix_seq_len, seq_length),
784
+ device=device,
785
+ dtype=causal_mask.dtype,
786
+ ),
787
+ causal_mask,
788
+ ],
789
+ axis=1,
790
+ )
791
+ causal_mask = torch.cat(
792
+ [
793
+ torch.ones(
794
+ (batch_size, causal_mask.shape[1], prefix_seq_len),
795
+ device=device,
796
+ dtype=causal_mask.dtype,
797
+ ),
798
+ causal_mask,
799
+ ],
800
+ axis=-1,
801
+ )
802
+ extended_attention_mask = (
803
+ causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
804
+ )
805
+ else:
806
+ extended_attention_mask = attention_mask[:, None, None, :]
807
+ else:
808
+ raise ValueError(
809
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
810
+ input_shape, attention_mask.shape
811
+ )
812
+ )
813
+
814
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
815
+ # masked positions, this operation will create a tensor which is 0.0 for
816
+ # positions we want to attend and -10000.0 for masked positions.
817
+ # Since we are adding it to the raw scores before the softmax, this is
818
+ # effectively the same as removing these entirely.
819
+ extended_attention_mask = extended_attention_mask.to(
820
+ dtype=self.dtype
821
+ ) # fp16 compatibility
822
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
823
+ return extended_attention_mask
824
+
825
+ def forward(
826
+ self,
827
+ input_ids=None,
828
+ attention_mask=None,
829
+ position_ids=None,
830
+ head_mask=None,
831
+ query_embeds=None,
832
+ encoder_hidden_states=None,
833
+ encoder_attention_mask=None,
834
+ past_key_values=None,
835
+ use_cache=None,
836
+ output_attentions=None,
837
+ output_hidden_states=None,
838
+ return_dict=None,
839
+ is_decoder=False,
840
+ ):
841
+ r"""
842
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
843
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
844
+ the model is configured as a decoder.
845
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
846
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
847
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
848
+ - 1 for tokens that are **not masked**,
849
+ - 0 for tokens that are **masked**.
850
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
851
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
852
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
853
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
854
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
855
+ use_cache (:obj:`bool`, `optional`):
856
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
857
+ decoding (see :obj:`past_key_values`).
858
+ """
859
+ output_attentions = (
860
+ output_attentions
861
+ if output_attentions is not None
862
+ else self.config.output_attentions
863
+ )
864
+ output_hidden_states = (
865
+ output_hidden_states
866
+ if output_hidden_states is not None
867
+ else self.config.output_hidden_states
868
+ )
869
+ return_dict = (
870
+ return_dict if return_dict is not None else self.config.use_return_dict
871
+ )
872
+
873
+ # use_cache = use_cache if use_cache is not None else self.config.use_cache
874
+
875
+ if input_ids is None:
876
+ assert (
877
+ query_embeds is not None
878
+ ), "You have to specify query_embeds when input_ids is None"
879
+
880
+ # past_key_values_length
881
+ past_key_values_length = (
882
+ past_key_values[0][0].shape[2] - self.config.query_length
883
+ if past_key_values is not None
884
+ else 0
885
+ )
886
+
887
+ query_length = query_embeds.shape[1] if query_embeds is not None else 0
888
+
889
+ embedding_output = self.embeddings(
890
+ input_ids=input_ids,
891
+ position_ids=position_ids,
892
+ query_embeds=query_embeds,
893
+ past_key_values_length=past_key_values_length,
894
+ )
895
+
896
+ input_shape = embedding_output.size()[:-1]
897
+ batch_size, seq_length = input_shape
898
+ device = embedding_output.device
899
+
900
+ if attention_mask is None:
901
+ attention_mask = torch.ones(
902
+ ((batch_size, seq_length + past_key_values_length)), device=device
903
+ )
904
+
905
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
906
+ # ourselves in which case we just need to make it broadcastable to all heads.
907
+ if is_decoder:
908
+ extended_attention_mask = self.get_extended_attention_mask(
909
+ attention_mask,
910
+ input_ids.shape,
911
+ device,
912
+ is_decoder,
913
+ has_query=(query_embeds is not None),
914
+ )
915
+ else:
916
+ extended_attention_mask = self.get_extended_attention_mask(
917
+ attention_mask, input_shape, device, is_decoder
918
+ )
919
+
920
+ # If a 2D or 3D attention mask is provided for the cross-attention
921
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
922
+ if encoder_hidden_states is not None:
923
+ if type(encoder_hidden_states) == list:
924
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
925
+ 0
926
+ ].size()
927
+ else:
928
+ (
929
+ encoder_batch_size,
930
+ encoder_sequence_length,
931
+ _,
932
+ ) = encoder_hidden_states.size()
933
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
934
+
935
+ if type(encoder_attention_mask) == list:
936
+ encoder_extended_attention_mask = [
937
+ self.invert_attention_mask(mask) for mask in encoder_attention_mask
938
+ ]
939
+ elif encoder_attention_mask is None:
940
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
941
+ encoder_extended_attention_mask = self.invert_attention_mask(
942
+ encoder_attention_mask
943
+ )
944
+ else:
945
+ encoder_extended_attention_mask = self.invert_attention_mask(
946
+ encoder_attention_mask
947
+ )
948
+ else:
949
+ encoder_extended_attention_mask = None
950
+
951
+ # Prepare head mask if needed
952
+ # 1.0 in head_mask indicate we keep the head
953
+ # attention_probs has shape bsz x n_heads x N x N
954
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
955
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
956
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
957
+
958
+ encoder_outputs = self.encoder(
959
+ embedding_output,
960
+ attention_mask=extended_attention_mask,
961
+ head_mask=head_mask,
962
+ encoder_hidden_states=encoder_hidden_states,
963
+ encoder_attention_mask=encoder_extended_attention_mask,
964
+ past_key_values=past_key_values,
965
+ use_cache=use_cache,
966
+ output_attentions=output_attentions,
967
+ output_hidden_states=output_hidden_states,
968
+ return_dict=return_dict,
969
+ query_length=query_length,
970
+ )
971
+ sequence_output = encoder_outputs[0]
972
+ pooled_output = (
973
+ self.pooler(sequence_output) if self.pooler is not None else None
974
+ )
975
+
976
+ if not return_dict:
977
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
978
+
979
+ return BaseModelOutputWithPoolingAndCrossAttentions(
980
+ last_hidden_state=sequence_output,
981
+ pooler_output=pooled_output,
982
+ past_key_values=encoder_outputs.past_key_values,
983
+ hidden_states=encoder_outputs.hidden_states,
984
+ attentions=encoder_outputs.attentions,
985
+ cross_attentions=encoder_outputs.cross_attentions,
986
+ )
987
+
988
+
989
+ class BertLMHeadModel(BertPreTrainedModel):
990
+
991
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
992
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
993
+
994
+ def __init__(self, config):
995
+ super().__init__(config)
996
+
997
+ self.bert = BertModel(config, add_pooling_layer=False)
998
+ self.cls = BertOnlyMLMHead(config)
999
+
1000
+ self.init_weights()
1001
+
1002
+ def get_output_embeddings(self):
1003
+ return self.cls.predictions.decoder
1004
+
1005
+ def set_output_embeddings(self, new_embeddings):
1006
+ self.cls.predictions.decoder = new_embeddings
1007
+
1008
+ def forward(
1009
+ self,
1010
+ input_ids=None,
1011
+ attention_mask=None,
1012
+ position_ids=None,
1013
+ head_mask=None,
1014
+ query_embeds=None,
1015
+ encoder_hidden_states=None,
1016
+ encoder_attention_mask=None,
1017
+ labels=None,
1018
+ past_key_values=None,
1019
+ use_cache=True,
1020
+ output_attentions=None,
1021
+ output_hidden_states=None,
1022
+ return_dict=None,
1023
+ return_logits=False,
1024
+ is_decoder=True,
1025
+ reduction="mean",
1026
+ ):
1027
+ r"""
1028
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
1029
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1030
+ the model is configured as a decoder.
1031
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1032
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1033
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
1034
+ - 1 for tokens that are **not masked**,
1035
+ - 0 for tokens that are **masked**.
1036
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1037
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1038
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
1039
+ ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1040
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1041
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1042
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
1043
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
1044
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
1045
+ use_cache (:obj:`bool`, `optional`):
1046
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
1047
+ decoding (see :obj:`past_key_values`).
1048
+ Returns:
1049
+ Example::
1050
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
1051
+ >>> import torch
1052
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
1053
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
1054
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
1055
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1056
+ >>> outputs = model(**inputs)
1057
+ >>> prediction_logits = outputs.logits
1058
+ """
1059
+ return_dict = (
1060
+ return_dict if return_dict is not None else self.config.use_return_dict
1061
+ )
1062
+ if labels is not None:
1063
+ use_cache = False
1064
+ if past_key_values is not None:
1065
+ query_embeds = None
1066
+
1067
+ outputs = self.bert(
1068
+ input_ids,
1069
+ attention_mask=attention_mask,
1070
+ position_ids=position_ids,
1071
+ head_mask=head_mask,
1072
+ query_embeds=query_embeds,
1073
+ encoder_hidden_states=encoder_hidden_states,
1074
+ encoder_attention_mask=encoder_attention_mask,
1075
+ past_key_values=past_key_values,
1076
+ use_cache=use_cache,
1077
+ output_attentions=output_attentions,
1078
+ output_hidden_states=output_hidden_states,
1079
+ return_dict=return_dict,
1080
+ is_decoder=is_decoder,
1081
+ )
1082
+
1083
+ sequence_output = outputs[0]
1084
+ if query_embeds is not None:
1085
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1086
+
1087
+ prediction_scores = self.cls(sequence_output)
1088
+
1089
+ if return_logits:
1090
+ return prediction_scores[:, :-1, :].contiguous()
1091
+
1092
+ lm_loss = None
1093
+ if labels is not None:
1094
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1095
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1096
+ labels = labels[:, 1:].contiguous()
1097
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
1098
+ lm_loss = loss_fct(
1099
+ shifted_prediction_scores.view(-1, self.config.vocab_size),
1100
+ labels.view(-1),
1101
+ )
1102
+ if reduction == "none":
1103
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
1104
+
1105
+ if not return_dict:
1106
+ output = (prediction_scores,) + outputs[2:]
1107
+ return ((lm_loss,) + output) if lm_loss is not None else output
1108
+
1109
+ return CausalLMOutputWithCrossAttentions(
1110
+ loss=lm_loss,
1111
+ logits=prediction_scores,
1112
+ past_key_values=outputs.past_key_values,
1113
+ hidden_states=outputs.hidden_states,
1114
+ attentions=outputs.attentions,
1115
+ cross_attentions=outputs.cross_attentions,
1116
+ )
1117
+
1118
+ def prepare_inputs_for_generation(
1119
+ self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
1120
+ ):
1121
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1122
+ if attention_mask is None:
1123
+ attention_mask = input_ids.new_ones(input_ids.shape)
1124
+ query_mask = input_ids.new_ones(query_embeds.shape[:-1])
1125
+ attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
1126
+
1127
+ # cut decoder_input_ids if past is used
1128
+ if past is not None:
1129
+ input_ids = input_ids[:, -1:]
1130
+
1131
+ return {
1132
+ "input_ids": input_ids,
1133
+ "query_embeds": query_embeds,
1134
+ "attention_mask": attention_mask,
1135
+ "past_key_values": past,
1136
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
1137
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
1138
+ "is_decoder": True,
1139
+ }
1140
+
1141
+ def _reorder_cache(self, past, beam_idx):
1142
+ reordered_past = ()
1143
+ for layer_past in past:
1144
+ reordered_past += (
1145
+ tuple(
1146
+ past_state.index_select(0, beam_idx) for past_state in layer_past
1147
+ ),
1148
+ )
1149
+ return reordered_past
1150
+
1151
+
1152
+ class BertForMaskedLM(BertPreTrainedModel):
1153
+
1154
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1155
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1156
+
1157
+ def __init__(self, config):
1158
+ super().__init__(config)
1159
+
1160
+ self.bert = BertModel(config, add_pooling_layer=False)
1161
+ self.cls = BertOnlyMLMHead(config)
1162
+
1163
+ self.init_weights()
1164
+
1165
+ def get_output_embeddings(self):
1166
+ return self.cls.predictions.decoder
1167
+
1168
+ def set_output_embeddings(self, new_embeddings):
1169
+ self.cls.predictions.decoder = new_embeddings
1170
+
1171
+ def forward(
1172
+ self,
1173
+ input_ids=None,
1174
+ attention_mask=None,
1175
+ position_ids=None,
1176
+ head_mask=None,
1177
+ query_embeds=None,
1178
+ encoder_hidden_states=None,
1179
+ encoder_attention_mask=None,
1180
+ labels=None,
1181
+ output_attentions=None,
1182
+ output_hidden_states=None,
1183
+ return_dict=None,
1184
+ return_logits=False,
1185
+ is_decoder=False,
1186
+ ):
1187
+ r"""
1188
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1189
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1190
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1191
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1192
+ """
1193
+
1194
+ return_dict = (
1195
+ return_dict if return_dict is not None else self.config.use_return_dict
1196
+ )
1197
+
1198
+ outputs = self.bert(
1199
+ input_ids,
1200
+ attention_mask=attention_mask,
1201
+ position_ids=position_ids,
1202
+ head_mask=head_mask,
1203
+ query_embeds=query_embeds,
1204
+ encoder_hidden_states=encoder_hidden_states,
1205
+ encoder_attention_mask=encoder_attention_mask,
1206
+ output_attentions=output_attentions,
1207
+ output_hidden_states=output_hidden_states,
1208
+ return_dict=return_dict,
1209
+ is_decoder=is_decoder,
1210
+ )
1211
+
1212
+ if query_embeds is not None:
1213
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1214
+ prediction_scores = self.cls(sequence_output)
1215
+
1216
+ if return_logits:
1217
+ return prediction_scores
1218
+
1219
+ masked_lm_loss = None
1220
+ if labels is not None:
1221
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1222
+ masked_lm_loss = loss_fct(
1223
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
1224
+ )
1225
+
1226
+ if not return_dict:
1227
+ output = (prediction_scores,) + outputs[2:]
1228
+ return (
1229
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1230
+ )
1231
+
1232
+ return MaskedLMOutput(
1233
+ loss=masked_lm_loss,
1234
+ logits=prediction_scores,
1235
+ hidden_states=outputs.hidden_states,
1236
+ attentions=outputs.attentions,
1237
+ )
1238
+
1239
+
1240
+ def build_qformer(num_query_token, vision_width,
1241
+ qformer_hidden_dropout_prob=0.1,
1242
+ qformer_attention_probs_dropout_prob=0.1,
1243
+ qformer_drop_path_rate=0.,
1244
+ bert_type="bert-base-uncased"
1245
+ ):
1246
+
1247
+ try:
1248
+ encoder_config = BertConfig.from_pretrained(bert_type, local_files_only=True)
1249
+ except:
1250
+ encoder_config = BertConfig.from_pretrained(bert_type)
1251
+ encoder_config.encoder_width = vision_width
1252
+ # insert cross-attention layer every other block
1253
+ encoder_config.add_cross_attention = True
1254
+ encoder_config.cross_attention_freq = 2
1255
+ encoder_config.query_length = num_query_token
1256
+ encoder_config.hidden_dropout_prob = qformer_hidden_dropout_prob
1257
+ encoder_config.attention_probs_dropout_prob = qformer_attention_probs_dropout_prob
1258
+ encoder_config.drop_path_list = [x.item() for x in torch.linspace(0, qformer_drop_path_rate, encoder_config.num_hidden_layers)]
1259
+ logger.info(f"Drop_path:{encoder_config.drop_path_list}")
1260
+ logger.info(encoder_config)
1261
+ Qformer = BertLMHeadModel.from_pretrained(
1262
+ bert_type, config=encoder_config, local_files_only=True
1263
+ )
1264
+ query_tokens = nn.Parameter(
1265
+ torch.zeros(1, num_query_token, encoder_config.hidden_size)
1266
+ )
1267
+ query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
1268
+ return Qformer, query_tokens
1269
+
1270
+
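
build_qformer above returns the BERT-based Q-Former together with its learnable query tokens; the queries cross-attend over frozen visual features every other layer. A rough usage sketch over dummy visual features, assuming illustrative sizes and that the default `bert-base-uncased` weights are downloadable:

# Hypothetical sketch of driving the Q-Former; batch size, token count and
# vision width below are illustrative assumptions, not values fixed by this file.
import torch

qformer, query_tokens = build_qformer(
    num_query_token=32,      # learnable queries
    vision_width=1408,       # channel dim of the frozen vision encoder output
    qformer_drop_path_rate=0.2,
)
visual_embeds = torch.randn(2, 256, 1408)                             # (B, N_vis, C)
visual_atts = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
out = qformer.bert(
    query_embeds=query_tokens.expand(visual_embeds.shape[0], -1, -1),
    encoder_hidden_states=visual_embeds,
    encoder_attention_mask=visual_atts,
    return_dict=True,
)
query_feats = out.last_hidden_state[:, :query_tokens.size(1), :]      # (2, 32, hidden_size)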
modeling_videochat2.py ADDED
@@ -0,0 +1,179 @@
1
+ import io
2
+ import logging
3
+ import torch
4
+ import torch.utils.checkpoint
5
+ from torch import nn
6
+ from torch.nn import MSELoss
7
+ from transformers.modeling_outputs import (
8
+ CausalLMOutputWithPast,
9
+ )
10
+ from typing import List, Optional, Tuple, Union
11
+ from torch.cuda.amp import autocast as autocast
12
+ from .modeling_base import BaseMLLM
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class InternVideo2_VideoChat2(BaseMLLM):
18
+
19
+ def __init__(
20
+ self,
21
+ config
22
+ ):
23
+ super().__init__(config=config)
24
+
25
+ def forward(
26
+ self,
27
+ input_ids: torch.LongTensor = None,
28
+ attention_mask: Optional[torch.Tensor] = None,
29
+ labels: Optional[torch.LongTensor] = None,
30
+ image: Optional[torch.Tensor] = None,
31
+ video: Optional[torch.Tensor] = None,
32
+ instruction = None,
33
+ video_idx = None,
34
+ image_idx = None,
35
+ ):
36
+ # print('Model Forwarding')
37
+
38
+ if self.use_vision_regression_loss:
39
+ text_embeds, visual, visual_idx = self.pad_text_embeds(input_ids=input_ids, image=image,video=video, return_visual=True, video_idx=video_idx, image_idx=image_idx, instruction = instruction)
40
+ else:
41
+ text_embeds = self.pad_text_embeds(input_ids=input_ids, image=image, video=video, return_visual=False, video_idx=video_idx, image_idx=image_idx, instruction = instruction)
42
+
43
+ outputs = self.lm(
44
+ inputs_embeds=text_embeds,
45
+ attention_mask=attention_mask,
46
+ labels=labels,
47
+ output_hidden_states=True,
48
+ return_dict=True,
49
+ )
50
+
51
+ return outputs
52
+
53
+ def pad_text_embeds(
54
+ self,
55
+ input_ids: torch.LongTensor = None,
56
+ image: Optional[torch.Tensor] = None,
57
+ video: Optional[torch.Tensor] = None,
58
+ image_idx = None,
59
+ video_idx = None,
60
+ return_visual: bool = False,
61
+ instruction = None,
62
+ ):
63
+ # text_embeds
64
+ text_embeds = self.lm.get_input_embeddings()(input_ids.long()).detach()
65
+
66
+ visual = None
67
+ visual_idx = None
68
+
69
+ if image is not None:
70
+ B, T, C, H, W = image.shape
71
+ image = image.permute(0, 2, 1, 3, 4)
72
+ prompt_image_embeds = self.encode_vision(image, instruction=instruction)
73
+ visual = prompt_image_embeds
74
+ prompt_image_embeds = self.project_up(prompt_image_embeds)
75
+ prompt_image_embeds = prompt_image_embeds.view(-1, prompt_image_embeds.shape[-1])
76
+ visual_idx = image_idx
77
+ text_embeds[image_idx == 1] = text_embeds[image_idx == 1] * 0 + prompt_image_embeds.to(text_embeds.device).to(text_embeds.dtype)
78
+ elif video is not None:
79
+ if len(video.shape) == 5:
80
+ B, T, C, H, W = video.shape
81
+ N = 1
82
+ else:
83
+ B, N, T, C, H, W = video.shape
84
+ video = video.reshape(B*N, T, C, H, W).permute(0, 2, 1, 3, 4)
85
+ prompt_video_embeds = self.encode_vision(video, instruction=instruction)
86
+ visual = prompt_video_embeds
87
+ prompt_video_embeds = self.project_up(prompt_video_embeds)
88
+ prompt_video_embeds = prompt_video_embeds.view(-1, prompt_video_embeds.shape[-1])
89
+ visual_idx = video_idx
90
+ text_embeds[video_idx == 1] = text_embeds[video_idx == 1] * 0 + prompt_video_embeds.to(text_embeds.device).to(text_embeds.dtype)
91
+ else:
92
+ logger.warning(f"No visual input provided; input_ids: {input_ids}")
93
+
94
+ if return_visual:
95
+ return text_embeds, visual, visual_idx
96
+
97
+ return text_embeds
98
+
99
+
100
+ def encode_vision(
101
+ self,
102
+ image,
103
+ instruction
104
+ ):
105
+ device = image.device
106
+ B = image.shape[0]
107
+ T = image.shape[2]
108
+ use_image = True if T == 1 else False
109
+ image_embeds = self.vision_encoder(image, use_image=use_image)
110
+ C = image_embeds.shape[-1]
111
+ image_embeds = image_embeds.reshape(B, -1, C)
112
+ image_embeds = self.vision_layernorm(image_embeds).to(device) # [B, T*L, C]
113
+
114
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
115
+ if self.extra_num_query_token > 0:
116
+ query_tokens = torch.cat([self.query_tokens, self.extra_query_tokens], dim=1)
117
+ query_tokens = query_tokens.expand(image_embeds.shape[0], -1, -1)
118
+ if instruction is not None:
119
+ text_Qformer = self.qformer_tokenizer(
120
+ instruction,
121
+ padding='longest',
122
+ truncation=True,
123
+ max_length=512,
124
+ return_tensors="pt",
125
+ ).to(image_embeds.device)
126
+ query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image_embeds.device)
127
+ Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
128
+ query_output = self.qformer.bert(
129
+ text_Qformer.input_ids,
130
+ attention_mask=Qformer_atts,
131
+ query_embeds=query_tokens,
132
+ encoder_hidden_states=image_embeds,
133
+ encoder_attention_mask=image_atts,
134
+ return_dict=True,
135
+ )
136
+ else:
137
+ query_output = self.qformer.bert(
138
+ query_embeds=query_tokens,
139
+ encoder_hidden_states=image_embeds,
140
+ encoder_attention_mask=image_atts,
141
+ return_dict=True,
142
+ )
143
+
144
+ return query_output.last_hidden_state[:, :query_tokens.size(1), :]
145
+
146
+
147
+ def generate_caption(
148
+ self,
149
+ input_ids,
150
+ attention_mask,
151
+ image_idx = None,
152
+ video_idx = None,
153
+ image: Optional[torch.Tensor] = None,
154
+ video: Optional[torch.Tensor] = None,
155
+ num_beams=1,
156
+ max_new_tokens=200,
157
+ do_sample=True,
158
+ top_p=0.9,
159
+ top_k=None,
160
+ temperature=1.0,
161
+ length_penalty=1,
162
+ repetition_penalty=1.0,
163
+ ):
164
+ text_embeds = self.pad_text_embeds(input_ids=input_ids, image=image, video=video, image_idx=image_idx, video_idx=video_idx)
165
+ outputs = self.lm.generate(
166
+ inputs_embeds=text_embeds,
167
+ attention_mask=attention_mask,
168
+ num_beams=num_beams,
169
+ max_new_tokens=max_new_tokens,
170
+ do_sample=do_sample,
171
+ min_length=1,
172
+ top_p=top_p,
173
+ top_k=top_k,
174
+ temperature=temperature,
175
+ length_penalty=length_penalty,
176
+ repetition_penalty=repetition_penalty,
177
+ )
178
+
179
+ return outputs
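
For reference, a rough captioning call against this class, assuming `model` and `tokenizer` have already been loaded from the repository with `trust_remote_code=True`. The prompt layout and the 96-token visual placeholder (32 queries plus 64 extra queries) are assumptions for illustration, not an official template:

# Hypothetical inference sketch; shapes and the placeholder construction are assumptions.
import torch

n_vis = 96                                    # assumed total number of query tokens
text_ids = tokenizer("Describe the video in detail.", return_tensors="pt").input_ids
placeholder = torch.full((1, n_vis), tokenizer.unk_token_id, dtype=torch.long)
input_ids = torch.cat([placeholder, text_ids], dim=1).to(model.device)
attention_mask = torch.ones_like(input_ids)
video_idx = torch.zeros_like(input_ids)
video_idx[:, :n_vis] = 1                      # mark where the video features are spliced in

video = torch.randn(1, 16, 3, 224, 224).to(model.device)   # (B, T, C, H, W), 16 frames
with torch.no_grad():
    output_ids = model.generate_caption(
        input_ids=input_ids,
        attention_mask=attention_mask,
        video=video,
        video_idx=video_idx,
        do_sample=False,
        max_new_tokens=128,
    )
caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)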
special_tokens_map.json CHANGED
@@ -13,7 +13,6 @@
13
  "rstrip": false,
14
  "single_word": false
15
  },
16
- "pad_token": "<unk>",
17
  "unk_token": {
18
  "content": "<unk>",
19
  "lstrip": false,
 
13
  "rstrip": false,
14
  "single_word": false
15
  },
 
16
  "unk_token": {
17
  "content": "<unk>",
18
  "lstrip": false,
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json CHANGED
@@ -6178,10 +6178,10 @@
6178
  "eos_token": "</s>",
6179
  "legacy": false,
6180
  "model_max_length": 1000000000000000019884624838656,
6181
- "pad_token": "<unk>",
6182
  "sp_model_kwargs": {},
6183
  "spaces_between_special_tokens": false,
6184
- "tokenizer_class": "MultimodalLlamaTokenizer",
6185
  "unk_token": "<unk>",
6186
  "use_default_system_prompt": false
6187
  }
 
6178
  "eos_token": "</s>",
6179
  "legacy": false,
6180
  "model_max_length": 1000000000000000019884624838656,
6181
+ "pad_token": null,
6182
  "sp_model_kwargs": {},
6183
  "spaces_between_special_tokens": false,
6184
+ "tokenizer_class": "LlamaTokenizer",
6185
  "unk_token": "<unk>",
6186
  "use_default_system_prompt": false
6187
  }
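
With `pad_token` now set to `null` and the tokenizer class switched to the stock `LlamaTokenizer`, code that pads batched prompts has to pick a padding token itself. One common workaround at load time, shown as a sketch (the repository id below is a placeholder):

from transformers import AutoTokenizer

repo_id = "OpenGVLab/your-model-repo"   # placeholder; substitute the actual repository
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
if tokenizer.pad_token is None:
    # fall back to the unk or eos token when the checkpoint ships without a pad token
    tokenizer.pad_token = tokenizer.unk_token or tokenizer.eos_token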