ynhe committed
Commit 16dc4f2
Parent: 371f0d2
This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
Files changed (50)
  1. added_tokens.json +9 -0
  2. config.json +60 -0
  3. model-00001-of-00004.safetensors +3 -0
  4. model-00002-of-00004.safetensors +3 -0
  5. model-00003-of-00004.safetensors +3 -0
  6. model-00004-of-00004.safetensors +3 -0
  7. model.safetensors.index.json +0 -0
  8. model_config.py +24 -0
  9. modeling_base.py +387 -0
  10. modeling_qformer.py +1264 -0
  11. modeling_special_token.py +27 -0
  12. modeling_videochate.py +681 -0
  13. modeling_vit.py +487 -0
  14. special_tokens_map.json +24 -0
  15. third_party/__init__.py +2 -0
  16. third_party/cgdetr/cg_detr/__init__.py +0 -0
  17. third_party/cgdetr/cg_detr/__pycache__/__init__.cpython-310.pyc +0 -0
  18. third_party/cgdetr/cg_detr/__pycache__/attention.cpython-310.pyc +0 -0
  19. third_party/cgdetr/cg_detr/__pycache__/crossattention.cpython-310.pyc +0 -0
  20. third_party/cgdetr/cg_detr/__pycache__/matcher.cpython-310.pyc +0 -0
  21. third_party/cgdetr/cg_detr/__pycache__/misc.cpython-310.pyc +0 -0
  22. third_party/cgdetr/cg_detr/__pycache__/model.cpython-310.pyc +0 -0
  23. third_party/cgdetr/cg_detr/__pycache__/position_encoding.cpython-310.pyc +0 -0
  24. third_party/cgdetr/cg_detr/__pycache__/span_utils.cpython-310.pyc +0 -0
  25. third_party/cgdetr/cg_detr/__pycache__/transformer.cpython-310.pyc +0 -0
  26. third_party/cgdetr/cg_detr/attention.py +394 -0
  27. third_party/cgdetr/cg_detr/config.py +261 -0
  28. third_party/cgdetr/cg_detr/crossattention.py +396 -0
  29. third_party/cgdetr/cg_detr/inference.py +480 -0
  30. third_party/cgdetr/cg_detr/matcher.py +109 -0
  31. third_party/cgdetr/cg_detr/misc.py +21 -0
  32. third_party/cgdetr/cg_detr/model.py +1178 -0
  33. third_party/cgdetr/cg_detr/position_encoding.py +116 -0
  34. third_party/cgdetr/cg_detr/postprocessing_cg_detr.py +95 -0
  35. third_party/cgdetr/cg_detr/scripts/charades_sta/inference.sh +8 -0
  36. third_party/cgdetr/cg_detr/scripts/charades_sta/train.sh +95 -0
  37. third_party/cgdetr/cg_detr/scripts/inference.sh +11 -0
  38. third_party/cgdetr/cg_detr/scripts/train.sh +76 -0
  39. third_party/cgdetr/cg_detr/span_utils.py +127 -0
  40. third_party/cgdetr/cg_detr/start_end_dataset.py +383 -0
  41. third_party/cgdetr/cg_detr/text_encoder.py +53 -0
  42. third_party/cgdetr/cg_detr/train.py +283 -0
  43. third_party/cgdetr/cg_detr/transformer.py +871 -0
  44. third_party/cgdetr/data/LICENSE +437 -0
  45. third_party/cgdetr/data/README.md +24 -0
  46. third_party/cgdetr/standalone_eval/README.md +54 -0
  47. third_party/cgdetr/standalone_eval/eval.py +361 -0
  48. third_party/cgdetr/standalone_eval/eval_sample.sh +10 -0
  49. third_party/cgdetr/standalone_eval/utils.py +209 -0
  50. third_party/cgdetr/utils/basic_utils.py +221 -0
added_tokens.json ADDED
@@ -0,0 +1,9 @@
+{
+  "<box_begin>": 32000,
+  "<boxes>": 32003,
+  "<temp>": 32002,
+  "<time_begin>": 32001,
+  "<track_begin>": 32004,
+  "<track_box>": 32006,
+  "<tracking>": 32005
+}
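
A minimal sketch (not part of the commit) of how these entries line up with the tokenizer shipped alongside them: the seven grounding/tracking tokens should resolve to the IDs 32000-32006 listed above. The repo path below is an assumption.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/this/repo")  # hypothetical local clone or hub id
for tok in ["<box_begin>", "<time_begin>", "<temp>", "<boxes>",
            "<track_begin>", "<tracking>", "<track_box>"]:
    # expected output: the IDs recorded in added_tokens.json (32000..32006)
    print(tok, tokenizer.convert_tokens_to_ids(tok))
```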
config.json ADDED
@@ -0,0 +1,60 @@
+{
+  "_name_or_path": "",
+  "architectures": [
+    "MultiModalLLM_PT"
+  ],
+  "auto_map": {
+    "AutoConfig": "model_config.VideoChatEConfig",
+    "AutoModel": "modeling_videochate.MultiModalLLM_PT"
+  },
+  "model_config": {
+    "bridge": {
+      "extra_num_query_token": 64,
+      "name": "qformer",
+      "num_query_token": 32,
+      "qformer_attention_probs_dropout_prob": 0.1,
+      "qformer_drop_path_rate": 0.2,
+      "qformer_hidden_dropout_prob": 0.1
+    },
+    "freeze_bridge": false,
+    "freeze_llm": false,
+    "freeze_vision_encoder": false,
+    "llm": {
+      "lora_alpha": 32,
+      "lora_dropout": 0.1,
+      "lora_r": 16,
+      "name": "mistral_7b",
+      "pretrained_llm_path": "mistralai/Mistral-7B-Instruct-v0.3",
+      "use_lora": true,
+      "hidden_size": 4096
+    },
+    "loss": {
+      "use_vision_regression_loss": false
+    },
+    "pretrained_paths": {},
+
+    "vision_encoder": {
+      "name":"vit_l14",
+      "img_size":224,
+      "patch_size":16,
+      "d_model":1024,
+      "encoder_embed_dim":1024,
+      "encoder_depth":24,
+      "encoder_num_heads":16,
+      "drop_path_rate": 0.0,
+      "num_frames":16,
+      "tubelet_size":1,
+      "use_checkpoint":false,
+      "checkpoint_num":0,
+      "return_index":-2,
+      "vit_add_ln":true,
+      "pretrained": null
+    }
+  },
+  "torch_dtype": "float32",
+  "transformers_version": "4.38.0",
+  "use_flash_attention": true,
+  "use_cache": true,
+  "build_decoder":true,
+  "hidden_size": 4096
+}
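
The `auto_map` block above points `AutoConfig`/`AutoModel` at the custom classes shipped in this commit (`model_config.VideoChatEConfig`, `modeling_videochate.MultiModalLLM_PT`), so loading goes through `trust_remote_code`. A minimal sketch under that assumption; the repo path is hypothetical and additional arguments (e.g. a tokenizer) may be required by the model's own `__init__`:

```python
import torch
from transformers import AutoConfig, AutoModel

repo = "path/to/this/repo"  # hypothetical local clone or hub id
config = AutoConfig.from_pretrained(repo, trust_remote_code=True)   # resolves to VideoChatEConfig
model = AutoModel.from_pretrained(
    repo,
    config=config,
    torch_dtype=torch.float32,   # matches "torch_dtype" in config.json
    trust_remote_code=True,      # resolves to MultiModalLLM_PT
)
```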
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c8640081ef34803134daaa9c4a69693b680bce89b3dcc66ed22df3a26d97a330
+size 4995778624
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8907fc1447bbaac12ba0e687cd25a19810d2717fe7e73dad57f70f13532c086c
+size 4945367960
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5640bcb11fc22f6a7604176b29807f56a856126ed62a97a9e99585eb313e4f93
+size 4945392936
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d0360c7932a26ce651a78f5dc26e7b6735ff0bc019fb8adce231fea4375d4fe8
+size 1316534788
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
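
The index file maps each weight name to one of the four shards above. A sketch (not from the commit) of how a single tensor can be pulled out of the sharded checkpoint via that mapping; `weight_map` is the standard key used by sharded Hugging Face checkpoints:

```python
import json
from safetensors import safe_open

with open("model.safetensors.index.json") as f:
    index = json.load(f)["weight_map"]  # {"param.name": "model-0000X-of-00004.safetensors", ...}

name, shard = next(iter(index.items()))
with safe_open(shard, framework="pt", device="cpu") as f:
    tensor = f.get_tensor(name)
print(name, tuple(tensor.shape))
```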
 
model_config.py ADDED
@@ -0,0 +1,24 @@
+import copy
+import re, ast
+from transformers import AutoConfig, LlamaConfig
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+from easydict import EasyDict as MyEasyDict
+from importlib import import_module
+import os.path as osp
+import argparse
+import json
+from copy import deepcopy
+import sys
+
+
+class VideoChatEConfig(PretrainedConfig):
+    model_type = 'VideoChatE'
+
+    def __init__(
+        self,
+        model_config=None,
+        **kwargs):
+        super().__init__(**kwargs)
+        self.model_config = MyEasyDict(model_config)
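
`VideoChatEConfig` simply wraps the nested `model_config` dict from config.json in an `EasyDict`, so sub-fields become attribute-accessible throughout the modeling code. An illustrative snippet of that behaviour (values copied from config.json; not part of the commit):

```python
from easydict import EasyDict

model_config = EasyDict({
    "llm": {"name": "mistral_7b", "hidden_size": 4096},
    "bridge": {"num_query_token": 32},
})
# attribute access instead of model_config["llm"]["name"]
print(model_config.llm.name, model_config.bridge.num_query_token)
```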
modeling_base.py ADDED
@@ -0,0 +1,387 @@
+import io
+import logging
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import MSELoss
+from transformers.modeling_outputs import (
+    CausalLMOutputWithPast,
+)
+from typing import List, Optional, Tuple, Union
+from transformers import LlamaForCausalLM
+
+from torch.cuda.amp import autocast as autocast
+
+from .modeling_vit import build_vit
+from .modeling_qformer import build_qformer
+from .model_config import VideoChatEConfig
+logger = logging.getLogger(__name__)
+
+from transformers import LlamaTokenizer,AutoTokenizer,AutoModel,AutoModelForCausalLM,AutoProcessor
+from transformers import AutoConfig, PreTrainedModel
+
+import os
+import sys
+
+
+try:
+    from third_party.sam2.build_sam import build_sam2_video_predictor
+    from third_party.cgdetr.cg_detr.model import build_cgdetr_model
+except:
+    print("can not import sam2 and cg-detr, install them first.")
+
+DEFAULT_IMG_TOKEN = "[IMG]"
+DEFAULT_IMG_END_TOKEN = "[/IMG]"
+
+DEFAULT_IMAGE_TOKEN = "<image>"
+DEFAULT_VIDEO_TOKEN = "[VIDEO]"
+
+IMG_TOKEN = "[<IMG_PLH>]"
+VID_TOKEN = "[<VID_PLH>]"
+
+BOX_START = '<box_begin>'
+# BOX_END = '<box_end>'
+ATBOXES_PLACEHOLDER = '<box_begin><boxes>'
+# ATBOXES_PLACEHOLDER = '<box_begin>'
+BOXES_PLACEHOLDER = '<boxes>'
+EXPR_PLACEHOLDER = '<expr>'
+QUESTION_PLACEHOLDER = '<question>'
+TIME_START = '<time_begin>'
+# TIME_END = '<time_end>'
+TIME_PLACEHOLDER = '<temp>'
+ATTEMP_PLACEHOLDER = TIME_START + TIME_PLACEHOLDER
+# ATTEMP_PLACEHOLDER = TIME_START
+TRACK_START='<track_begin>'
+TRACK_PLACEHOLDER = '<tracking>'
+TRACK_START_BOX = '<track_box>'
+ATTRACK_PLACEHOLDER = TRACK_START + TRACK_PLACEHOLDER
+need_template_list = ['REC', 'flickr', 'tracking', 'tracking2', 'tracking3', 'tracking4']
+
+load_image_list = ['image', 'REC', 'flickr']
+load_video_list = ['video', 'TVG', 'tracking', 'tracking2','tracking3', 'tracking4', 'TVG+HL']
+special_tokens = [BOX_START, TIME_START, TIME_PLACEHOLDER, BOXES_PLACEHOLDER, TRACK_START, TRACK_PLACEHOLDER, TRACK_START_BOX]
+
+def disabled_train(self, mode=True):
+    """Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore."""
+    return self
+
+
+def freeze_module(module):
+    for _, param in module.named_parameters():
+        param.requires_grad = False
+    module = module.eval()
+    module.train = disabled_train
+    return module
+
+
+class LLMConfig(AutoConfig):
+    model_type = "20b"
+
+
+class BaseMLLM(PreTrainedModel):
+    config_class = VideoChatEConfig
+    def __init__(self, config,_tokenizer=None):
+        # super().__init__(config)
+        self.model_config = config.model_config
+        self.tokenizer = _tokenizer
+
+        config.cg_opt = None
+        config.model_config = None
+        config.model_tokenizer = None
+        super().__init__(config)
+        self.build_vision_encoder()
+        self.build_llm()
+        self.build_bridge()
+        self.build_loss()
+
+        self.load_pretrained_weights()
+        try:
+            if config.build_decoder:
+                self.cg_opt = config.cg_opt
+                self.build_bbox_decoder()
+                self.build_sam()
+                self.build_CGDETR()
+        except:
+            print("please install cgdetr and sam2 first")
+        logger.info(f'Length of tokenizer and resize embedding: {len(self.tokenizer)}')
+
+
+    def build_vision_encoder(self):
+        if 'internvideo2' in self.model_config.vision_encoder.name.lower():
+            encoder_name = self.model_config.vision_encoder.name
+            logger.info(f"Build vision_encoder: {encoder_name}")
+            if encoder_name == 'internvideo2-1B':
+                self.vision_encoder = pretrain_internvideo2_giant_patch14_224_clean(self.model_config)
+
+            else:
+                raise ValueError(f"Not implemented: {encoder_name}")
+        elif 'vit' in self.model_config.vision_encoder.name.lower():
+            self.vision_encoder = build_vit(self.model_config)
+        else:
+            raise NotImplementedError(self.model_config.vision_encoder.name)
+
+        if self.model_config.vision_encoder.vit_add_ln:
+            self.vision_layernorm = nn.LayerNorm(self.model_config.vision_encoder.encoder_embed_dim, eps=1e-12)
+        else:
+            self.vision_layernorm = nn.Identity()
+
+        self.freeze_vision_encoder = self.model_config.get("freeze_vision_encoder", False)
+
+        if self.freeze_vision_encoder:
+            logger.info("freeze vision encoder")
+            freeze_module(self.vision_encoder)
+            freeze_module(self.vision_layernorm)
+
+    def build_CGDETR(self):
+        self.cg_model, self.cg_criterion = build_cgdetr_model()
+
+    def build_bridge(self):
+        # ViT to LM: 1792 -> 6656 NOTE 768 is qformer dim
+        self.project_up = nn.Linear(768, self.lm.config.hidden_size) # whether bias is needed?
+        # LM to ViT: 6656 -> 1792
+        self.project_down = nn.Linear(self.lm.config.hidden_size, 768)
+
+        if 'qformer' in self.model_config.bridge.name.lower():
+            from transformers import BertTokenizer
+            self.qformer_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="left")
+            self.qformer_tokenizer.add_special_tokens({"bos_token": "[DEC]"})
+            self.qformer_tokenizer.padding_side = "left"
+            if self.model_config.bridge.name == 'qformer':
+                self.qformer, self.query_tokens = build_qformer(
+                    self.model_config.bridge.num_query_token, self.model_config.vision_encoder.encoder_embed_dim,
+                    qformer_hidden_dropout_prob=self.model_config.bridge.qformer_hidden_dropout_prob,
+                    qformer_attention_probs_dropout_prob=self.model_config.bridge.qformer_attention_probs_dropout_prob,
+                    qformer_drop_path_rate=self.model_config.bridge.qformer_drop_path_rate,
+                )
+            elif self.model_config.bridge.name == 'causal_qformer':
+                self.qformer, self.query_tokens = build_causal_qformer(
+                    self.model_config.bridge.num_query_token, self.model_config.vision_encoder.encoder_embed_dim,
+                    qformer_hidden_dropout_prob=self.model_config.bridge.qformer_hidden_dropout_prob,
+                    qformer_attention_probs_dropout_prob=self.model_config.bridge.qformer_attention_probs_dropout_prob
+                )
+            self.qformer.resize_token_embeddings(len(self.qformer_tokenizer))
+            self.qformer.cls = None
+            self.extra_num_query_token = self.model_config.bridge.extra_num_query_token
+            if self.model_config.bridge.extra_num_query_token > 0:
+                logger.info(f"Add extra {self.model_config.bridge.extra_num_query_token} tokens in QFormer")
+                self.extra_query_tokens = nn.Parameter(
+                    torch.zeros(1, self.model_config.bridge.extra_num_query_token, self.query_tokens.shape[-1])
+                )
+
+        self.freeze_bridge = self.model_config.get("freeze_bridge", False)
+        if self.freeze_bridge:
+            logger.info("freeze bridge")
+            freeze_module(self.qformer)
+            self.query_tokens.requires_grad = False
+
+    def build_llm(self):
+        self.lm_name = self.model_config.llm.name
+        if self.model_config.llm.name == "vicuna1.5_7b":
+            self.lm = LlamaForCausalLM.from_pretrained(self.model_config.llm.pretrained_llm_path)
+            self.lm.gradient_checkpointing = self.model_config.llm.get("use_llama_gradient_checkpointing", True)
+        elif self.model_config.llm.name == 'mistral_7b':
+            from transformers import AutoModelForCausalLM
+
+            config = AutoConfig.from_pretrained(
+                self.model_config.llm.pretrained_llm_path,
+                torch_dtype=torch.bfloat16,
+                # attn_implementation="flash_attention_2",
+            )
+            self.lm = AutoModelForCausalLM.from_config(config)
+        elif self.model_config.llm.name == 'internlm_20b':
+            from transformers import AutoModelForCausalLM
+            self.lm = AutoModelForCausalLM.from_pretrained(
+                self.model_config.llm.pretrained_llm_path,
+                torch_dtype=torch.bfloat16,
+                trust_remote_code=True,
+            )
+            self.lm.gradient_checkpointing = True
+            self.lm._set_gradient_checkpointing()
+        else:
+            raise NotImplementedError(self.model_config.llm.name)
+
+        num_new_tokens = len(special_tokens)
+        self.lm.resize_token_embeddings(len(self.tokenizer))
+
+        input_embeddings = self.lm.get_input_embeddings().weight.data
+        output_embeddings = self.lm.get_output_embeddings().weight.data
+
+        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+            dim=0, keepdim=True)
+        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+            dim=0, keepdim=True)
+
+        input_embeddings[-num_new_tokens:] = input_embeddings_avg
+        output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+        self.model_config.token_at_ids = self.tokenizer.convert_tokens_to_ids([BOX_START])[0]
+        self.freeze_llm = self.model_config.get("freeze_llm", True)
+        logger.info(f'freeze_llm: {self.freeze_llm}')
+        if self.freeze_llm:
+            logger.info("freeze llm")
+            freeze_module(self.lm)
+
+        if self.model_config.llm.use_lora:
+            self.use_lora = True
+            from peft import get_peft_model, LoraConfig, TaskType
+            logger.info("Use lora")
+            if self.model_config.llm.name == 'internlm_20b':
+                peft_config = LoraConfig(
+                    task_type=TaskType.CAUSAL_LM, inference_mode=False,
+                    r=self.model_config.llm.lora_r, lora_alpha=self.model_config.llm.lora_alpha, lora_dropout=self.model_config.llm.lora_dropout,
+                    target_modules=['wqkv', 'wo', 'w1', 'w2', 'w3', 'output']
+                )
+            else:
+                peft_config = LoraConfig(
+                    task_type=TaskType.CAUSAL_LM, inference_mode=False,
+                    r=self.model_config.llm.lora_r, lora_alpha=self.model_config.llm.lora_alpha, lora_dropout=self.model_config.llm.lora_dropout,
+                    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
+                                    "gate_proj", "up_proj", "down_proj", "lm_head"]
+                )
+
+            self.lm = get_peft_model(self.lm, peft_config)
+            self.lm.enable_input_require_grads()
+            self.lm.print_trainable_parameters()
+
+            if self.model_config.get("freeze_lora", False):
+                logger.info("freeze lora")
+                freeze_module(self.lm)
+                self.lm.print_trainable_parameters()
+
+        else:
+            self.use_lora = False
+
+    def add_lora(self):
+        if self.model_config.llm.use_lora:
+            self.use_lora = True
+            from peft import get_peft_model, LoraConfig, TaskType
+            logger.info("Use lora")
+            if self.model_config.llm.name == 'internlm_20b':
+                peft_config = LoraConfig(
+                    task_type=TaskType.CAUSAL_LM, inference_mode=False,
+                    r=self.model_config.llm.lora_r, lora_alpha=self.model_config.llm.lora_alpha, lora_dropout=self.model_config.llm.lora_dropout,
+                    target_modules=['wqkv', 'wo', 'w1', 'w2', 'w3', 'output']
+                )
+            else:
+                peft_config = LoraConfig(
+                    task_type=TaskType.CAUSAL_LM, inference_mode=False,
+                    r=self.model_config.llm.lora_r, lora_alpha=self.model_config.llm.lora_alpha, lora_dropout=self.model_config.llm.lora_dropout,
+                    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
+                                    "gate_proj", "up_proj", "down_proj", "lm_head"]
+                )
+
+            self.lm = get_peft_model(self.lm, peft_config)
+            self.lm.enable_input_require_grads()
+            self.lm.print_trainable_parameters()
+
+            if self.model_config.get("freeze_lora", False):
+                logger.info("freeze lora")
+                freeze_module(self.lm)
+                self.lm.print_trainable_parameters()
+
+        else:
+            self.use_lora = False
+
+    def add_tokens(self):
+        num_new_tokens = len(special_tokens)
+        self.lm.resize_token_embeddings(len(self.tokenizer))
+
+        input_embeddings = self.lm.get_input_embeddings().weight.data
+        output_embeddings = self.lm.get_output_embeddings().weight.data
+
+        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+            dim=0, keepdim=True)
+        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+            dim=0, keepdim=True)
+        print(self.lm.get_input_embeddings().weight.data.shape)
+        input_embeddings[-num_new_tokens:] = input_embeddings_avg
+        output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+        self.model_config.token_at_ids = self.tokenizer.convert_tokens_to_ids([BOX_START])[0]
+
+    def build_loss(self):
+        self.use_vision_regression_loss = self.model_config.loss.get("use_vision_regression_loss", False)
+        self.use_bbox_loss = self.model_config.loss.get("add_bbox_loss", False)
+        self.use_mask_loss = self.model_config.loss.get("use_mask_loss", False)
+        self.use_temporal_loss = self.model_config.loss.get('use_temporal_loss', False)
+        if self.use_vision_regression_loss:
+            self.image_loss_fct = MSELoss()
+
+
+    def load_pretrained_weights(self):
+        if self.model_config.pretrained_paths.get('pretrained_vit_qformer_path', None):
+            if 'safetensor' in self.model_config.pretrained_paths.pretrained_vit_qformer_path:
+                from safetensors import safe_open
+                from safetensors.torch import save_file
+                state_dict = {}
+                with safe_open(self.model_config.pretrained_paths.pretrained_vit_qformer_path, framework="pt", device="cpu") as f:
+                    for key in f.keys():
+                        state_dict[key] = f.get_tensor(key)
+            else:
+                state_dict = torch.load(self.model_config.pretrained_paths.pretrained_vit_qformer_path, map_location="cpu")
+                if "model" in state_dict.keys():
+                    state_dict = state_dict["model"]
+                elif "module" in state_dict.keys():
+                    state_dict = state_dict["module"] # for deepspeed
+            self.check_temp_emb(state_dict)
+            msg = self.load_state_dict(state_dict, strict=False)
+            print('Loading vit: ', msg)
+            logger.info(f"Load ViT and QFormer from {self.model_config.pretrained_paths.pretrained_vit_qformer_path}: {msg}")
+
+        if self.model_config.pretrained_paths.get('pretrained_videochat2', None):
+            state_dict = torch.load(self.model_config.pretrained_paths.pretrained_videochat2, map_location="cpu")
+
+            new_state_dict = {}
+            for k in state_dict.keys():
+                if 'bert.embeddings' not in k:
+                    new_state_dict[k] = state_dict[k]
+            state_dict = new_state_dict
+            # self.check_temp_emb(state_dict)
+            msg = self.load_state_dict(state_dict, strict=False)
+            print('Loading videochat2: ', msg)
+
+
+    def check_temp_emb(self, state_dict):
+        old_num_frames = self.model_config.vision_encoder.get('origin_num_frames', None)
+        new_num_frames = self.model_config.vision_encoder.num_frames
+        if old_num_frames is not None and old_num_frames != new_num_frames:
+            logger.info(f"interpolate_pos_embed_internvideo2 to {new_num_frames} (origin_num_frames={old_num_frames})!!!")
+            a = len(state_dict)
+            interpolate_pos_embed_internvideo2_new(state_dict, self.vision_encoder, orig_t_size=4)
+            assert a == len(state_dict), state_dict.keys()
+
+    def build_bbox_decoder(self):
+        self.loc_encoder = nn.Sequential(
+            nn.Linear(4, self.model_config.llm.hidden_size // 2, dtype=torch.bfloat16),
+            nn.ReLU(),
+            nn.Linear(self.model_config.llm.hidden_size // 2, self.model_config.llm.hidden_size, dtype=torch.bfloat16),
+        )
+
+        self.loc_decoder = nn.Sequential(
+            nn.Linear(self.model_config.llm.hidden_size, self.model_config.llm.hidden_size // 2, dtype=torch.bfloat16),
+            nn.ReLU(),
+            nn.Linear(self.model_config.llm.hidden_size // 2, 4, dtype=torch.bfloat16)
+        )
+        self._initialize_bbox_weights()
+
+    def _initialize_bbox_weights(self):
+        return
+
+    def build_sam(self):
+        sam2_checkpoint = "/cpfs01/user/heyinan/checkpoints/sam2_hiera_large.pt"
+        model_cfg = "sam2_hiera_l.yaml"
+        predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=self.lm.device)
+
+        self.sam = predictor
+        freeze_module(self.sam)
+
+
+    @property
+    def dtype(self):
+        return self.lm.dtype
+
+
+    @property
+    def device(self):
+        return self.lm.device
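
In `build_llm()` and `add_tokens()` above, the embedding matrix is resized for the seven new special tokens and the new rows are initialized to the mean of the pre-existing rows. A standalone sketch of that step (illustrative only; the tensor shapes and variable names here are hypothetical, not the repo's):

```python
import torch

num_new_tokens = 7                  # len(special_tokens) in modeling_base.py
emb = torch.randn(32007, 4096)      # resized input-embedding matrix: old vocab (32000) + 7 new rows
emb[-num_new_tokens:] = emb[:-num_new_tokens].mean(dim=0, keepdim=True)
# each new special-token row now starts from the average of the original vocabulary embeddings
```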
modeling_qformer.py ADDED
@@ -0,0 +1,1264 @@
1
+ """
2
+ * Copyright (c) 2023, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ """
10
+ import logging
11
+ import math
12
+ import os
13
+ import warnings
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple, Dict, Any
16
+
17
+ import torch
18
+ from torch import Tensor, device, dtype, nn
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+ from torch.nn import CrossEntropyLoss
22
+ import torch.nn.functional as F
23
+
24
+ from timm.models.layers import drop_path
25
+ from transformers.activations import ACT2FN
26
+ from transformers.file_utils import (
27
+ ModelOutput,
28
+ )
29
+ from transformers.modeling_outputs import (
30
+ BaseModelOutputWithPastAndCrossAttentions,
31
+ BaseModelOutputWithPoolingAndCrossAttentions,
32
+ CausalLMOutputWithCrossAttentions,
33
+ MaskedLMOutput,
34
+ MultipleChoiceModelOutput,
35
+ NextSentencePredictorOutput,
36
+ QuestionAnsweringModelOutput,
37
+ SequenceClassifierOutput,
38
+ TokenClassifierOutput,
39
+ )
40
+ from transformers.modeling_utils import (
41
+ PreTrainedModel,
42
+ apply_chunking_to_forward,
43
+ find_pruneable_heads_and_indices,
44
+ prune_linear_layer,
45
+ )
46
+ from transformers.models.bert.configuration_bert import BertConfig
47
+
48
+ import logging
49
+ logger = logging.getLogger(__name__)
50
+
51
+
52
+ class BertEmbeddings(nn.Module):
53
+ """Construct the embeddings from word and position embeddings."""
54
+
55
+ def __init__(self, config):
56
+ super().__init__()
57
+ self.word_embeddings = nn.Embedding(
58
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
59
+ )
60
+ self.position_embeddings = nn.Embedding(
61
+ config.max_position_embeddings, config.hidden_size
62
+ )
63
+
64
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
65
+ # any TensorFlow checkpoint file
66
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
67
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
68
+
69
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
70
+ self.register_buffer(
71
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
72
+ )
73
+ self.position_embedding_type = getattr(
74
+ config, "position_embedding_type", "absolute"
75
+ )
76
+
77
+ self.config = config
78
+
79
+ def forward(
80
+ self,
81
+ input_ids=None,
82
+ position_ids=None,
83
+ query_embeds=None,
84
+ past_key_values_length=0,
85
+ ):
86
+ if input_ids is not None:
87
+ seq_length = input_ids.size()[1]
88
+ else:
89
+ seq_length = 0
90
+
91
+ if position_ids is None:
92
+ position_ids = self.position_ids[
93
+ :, past_key_values_length : seq_length + past_key_values_length
94
+ ].clone()
95
+
96
+ if input_ids is not None:
97
+ embeddings = self.word_embeddings(input_ids)
98
+ if self.position_embedding_type == "absolute":
99
+ position_embeddings = self.position_embeddings(position_ids)
100
+ embeddings = embeddings + position_embeddings
101
+
102
+ if query_embeds is not None:
103
+ embeddings = torch.cat((query_embeds, embeddings), dim=1)
104
+ else:
105
+ embeddings = query_embeds
106
+
107
+ embeddings = self.LayerNorm(embeddings)
108
+ embeddings = self.dropout(embeddings)
109
+ return embeddings
110
+
111
+
112
+ class BertSelfAttention(nn.Module):
113
+ def __init__(self, config, is_cross_attention):
114
+ super().__init__()
115
+ self.config = config
116
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
117
+ config, "embedding_size"
118
+ ):
119
+ raise ValueError(
120
+ "The hidden size (%d) is not a multiple of the number of attention "
121
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
122
+ )
123
+
124
+ self.num_attention_heads = config.num_attention_heads
125
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
126
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
127
+
128
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
129
+ if is_cross_attention:
130
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
131
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
132
+ else:
133
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
134
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
135
+
136
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
137
+ self.position_embedding_type = getattr(
138
+ config, "position_embedding_type", "absolute"
139
+ )
140
+ if (
141
+ self.position_embedding_type == "relative_key"
142
+ or self.position_embedding_type == "relative_key_query"
143
+ ):
144
+ self.max_position_embeddings = config.max_position_embeddings
145
+ self.distance_embedding = nn.Embedding(
146
+ 2 * config.max_position_embeddings - 1, self.attention_head_size
147
+ )
148
+ self.save_attention = False
149
+
150
+ def save_attn_gradients(self, attn_gradients):
151
+ self.attn_gradients = attn_gradients
152
+
153
+ def get_attn_gradients(self):
154
+ return self.attn_gradients
155
+
156
+ def save_attention_map(self, attention_map):
157
+ self.attention_map = attention_map
158
+
159
+ def get_attention_map(self):
160
+ return self.attention_map
161
+
162
+ def transpose_for_scores(self, x):
163
+ new_x_shape = x.size()[:-1] + (
164
+ self.num_attention_heads,
165
+ self.attention_head_size,
166
+ )
167
+ x = x.view(*new_x_shape)
168
+ return x.permute(0, 2, 1, 3)
169
+
170
+ def forward(
171
+ self,
172
+ hidden_states,
173
+ attention_mask=None,
174
+ head_mask=None,
175
+ encoder_hidden_states=None,
176
+ encoder_attention_mask=None,
177
+ past_key_value=None,
178
+ output_attentions=False,
179
+ ):
180
+
181
+ # If this is instantiated as a cross-attention module, the keys
182
+ # and values come from an encoder; the attention mask needs to be
183
+ # such that the encoder's padding tokens are not attended to.
184
+ is_cross_attention = encoder_hidden_states is not None
185
+
186
+ if is_cross_attention:
187
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
188
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
189
+ attention_mask = encoder_attention_mask
190
+ elif past_key_value is not None:
191
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
192
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
193
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
194
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
195
+ else:
196
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
197
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
198
+
199
+ mixed_query_layer = self.query(hidden_states)
200
+
201
+ query_layer = self.transpose_for_scores(mixed_query_layer)
202
+
203
+ past_key_value = (key_layer, value_layer)
204
+
205
+ # Take the dot product between "query" and "key" to get the raw attention scores.
206
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
207
+
208
+ if (
209
+ self.position_embedding_type == "relative_key"
210
+ or self.position_embedding_type == "relative_key_query"
211
+ ):
212
+ seq_length = hidden_states.size()[1]
213
+ position_ids_l = torch.arange(
214
+ seq_length, dtype=torch.long, device=hidden_states.device
215
+ ).view(-1, 1)
216
+ position_ids_r = torch.arange(
217
+ seq_length, dtype=torch.long, device=hidden_states.device
218
+ ).view(1, -1)
219
+ distance = position_ids_l - position_ids_r
220
+ positional_embedding = self.distance_embedding(
221
+ distance + self.max_position_embeddings - 1
222
+ )
223
+ positional_embedding = positional_embedding.to(
224
+ dtype=query_layer.dtype
225
+ ) # fp16 compatibility
226
+
227
+ if self.position_embedding_type == "relative_key":
228
+ relative_position_scores = torch.einsum(
229
+ "bhld,lrd->bhlr", query_layer, positional_embedding
230
+ )
231
+ attention_scores = attention_scores + relative_position_scores
232
+ elif self.position_embedding_type == "relative_key_query":
233
+ relative_position_scores_query = torch.einsum(
234
+ "bhld,lrd->bhlr", query_layer, positional_embedding
235
+ )
236
+ relative_position_scores_key = torch.einsum(
237
+ "bhrd,lrd->bhlr", key_layer, positional_embedding
238
+ )
239
+ attention_scores = (
240
+ attention_scores
241
+ + relative_position_scores_query
242
+ + relative_position_scores_key
243
+ )
244
+
245
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
246
+ if attention_mask is not None:
247
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
248
+ attention_scores = attention_scores + attention_mask
249
+
250
+ # Normalize the attention scores to probabilities.
251
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
252
+
253
+ if is_cross_attention and self.save_attention:
254
+ self.save_attention_map(attention_probs)
255
+ attention_probs.register_hook(self.save_attn_gradients)
256
+
257
+ # This is actually dropping out entire tokens to attend to, which might
258
+ # seem a bit unusual, but is taken from the original Transformer paper.
259
+ attention_probs_dropped = self.dropout(attention_probs)
260
+
261
+ # Mask heads if we want to
262
+ if head_mask is not None:
263
+ attention_probs_dropped = attention_probs_dropped * head_mask
264
+
265
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
266
+
267
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
268
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
269
+ context_layer = context_layer.view(*new_context_layer_shape)
270
+
271
+ outputs = (
272
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
273
+ )
274
+
275
+ outputs = outputs + (past_key_value,)
276
+ return outputs
277
+
278
+
279
+ class DropPath(nn.Module):
280
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
281
+ """
282
+ def __init__(self, drop_prob=None):
283
+ super(DropPath, self).__init__()
284
+ self.drop_prob = drop_prob
285
+
286
+ def forward(self, x):
287
+ return drop_path(x, self.drop_prob, self.training)
288
+
289
+ def extra_repr(self) -> str:
290
+ return 'p={}'.format(self.drop_prob)
291
+
292
+
293
+ class BertSelfOutput(nn.Module):
294
+ def __init__(self, config, drop_path=0.):
295
+ super().__init__()
296
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
297
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
298
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
299
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
300
+
301
+ def forward(self, hidden_states, input_tensor):
302
+ hidden_states = self.dense(hidden_states)
303
+ hidden_states = self.dropout(hidden_states)
304
+ hidden_states = self.drop_path(hidden_states)
305
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
306
+ return hidden_states
307
+
308
+
309
+ class BertAttention(nn.Module):
310
+ def __init__(self, config, is_cross_attention=False, drop_path=0.,):
311
+ super().__init__()
312
+ self.self = BertSelfAttention(config, is_cross_attention)
313
+ self.output = BertSelfOutput(config, drop_path=drop_path)
314
+ self.pruned_heads = set()
315
+
316
+ def prune_heads(self, heads):
317
+ if len(heads) == 0:
318
+ return
319
+ heads, index = find_pruneable_heads_and_indices(
320
+ heads,
321
+ self.self.num_attention_heads,
322
+ self.self.attention_head_size,
323
+ self.pruned_heads,
324
+ )
325
+
326
+ # Prune linear layers
327
+ self.self.query = prune_linear_layer(self.self.query, index)
328
+ self.self.key = prune_linear_layer(self.self.key, index)
329
+ self.self.value = prune_linear_layer(self.self.value, index)
330
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
331
+
332
+ # Update hyper params and store pruned heads
333
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
334
+ self.self.all_head_size = (
335
+ self.self.attention_head_size * self.self.num_attention_heads
336
+ )
337
+ self.pruned_heads = self.pruned_heads.union(heads)
338
+
339
+ def forward(
340
+ self,
341
+ hidden_states,
342
+ attention_mask=None,
343
+ head_mask=None,
344
+ encoder_hidden_states=None,
345
+ encoder_attention_mask=None,
346
+ past_key_value=None,
347
+ output_attentions=False,
348
+ ):
349
+ self_outputs = self.self(
350
+ hidden_states,
351
+ attention_mask,
352
+ head_mask,
353
+ encoder_hidden_states,
354
+ encoder_attention_mask,
355
+ past_key_value,
356
+ output_attentions,
357
+ )
358
+ attention_output = self.output(self_outputs[0], hidden_states)
359
+
360
+ outputs = (attention_output,) + self_outputs[
361
+ 1:
362
+ ] # add attentions if we output them
363
+ return outputs
364
+
365
+
366
+ class BertIntermediate(nn.Module):
367
+ def __init__(self, config):
368
+ super().__init__()
369
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
370
+ if isinstance(config.hidden_act, str):
371
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
372
+ else:
373
+ self.intermediate_act_fn = config.hidden_act
374
+
375
+ def forward(self, hidden_states):
376
+ hidden_states = self.dense(hidden_states)
377
+ hidden_states = self.intermediate_act_fn(hidden_states)
378
+ return hidden_states
379
+
380
+
381
+ class BertOutput(nn.Module):
382
+ def __init__(self, config, drop_path=0.):
383
+ super().__init__()
384
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
385
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
386
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
387
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
388
+
389
+ def forward(self, hidden_states, input_tensor):
390
+ hidden_states = self.dense(hidden_states)
391
+ hidden_states = self.dropout(hidden_states)
392
+ hidden_states = self.drop_path(hidden_states)
393
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
394
+ return hidden_states
395
+
396
+
397
+ class BertLayer(nn.Module):
398
+ def __init__(self, config, layer_num):
399
+ super().__init__()
400
+ self.config = config
401
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
402
+ self.seq_len_dim = 1
403
+ drop_path = config.drop_path_list[layer_num]
404
+ self.attention = BertAttention(config, drop_path=drop_path)
405
+ self.layer_num = layer_num
406
+ if (
407
+ self.config.add_cross_attention
408
+ and layer_num % self.config.cross_attention_freq == 0
409
+ ):
410
+ self.crossattention = BertAttention(
411
+ config, is_cross_attention=self.config.add_cross_attention,
412
+ drop_path=drop_path
413
+ )
414
+ self.has_cross_attention = True
415
+ else:
416
+ self.has_cross_attention = False
417
+ self.intermediate = BertIntermediate(config)
418
+ self.output = BertOutput(config, drop_path=drop_path)
419
+
420
+ self.intermediate_query = BertIntermediate(config)
421
+ self.output_query = BertOutput(config, drop_path=drop_path)
422
+
423
+ def forward(
424
+ self,
425
+ hidden_states,
426
+ attention_mask=None,
427
+ head_mask=None,
428
+ encoder_hidden_states=None,
429
+ encoder_attention_mask=None,
430
+ past_key_value=None,
431
+ output_attentions=False,
432
+ query_length=0,
433
+ ):
434
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
435
+ self_attn_past_key_value = (
436
+ past_key_value[:2] if past_key_value is not None else None
437
+ )
438
+ self_attention_outputs = self.attention(
439
+ hidden_states,
440
+ attention_mask,
441
+ head_mask,
442
+ output_attentions=output_attentions,
443
+ past_key_value=self_attn_past_key_value,
444
+ )
445
+ attention_output = self_attention_outputs[0]
446
+ outputs = self_attention_outputs[1:-1]
447
+
448
+ present_key_value = self_attention_outputs[-1]
449
+
450
+ if query_length > 0:
451
+ query_attention_output = attention_output[:, :query_length, :]
452
+
453
+ if self.has_cross_attention:
454
+ assert (
455
+ encoder_hidden_states is not None
456
+ ), "encoder_hidden_states must be given for cross-attention layers"
457
+ cross_attention_outputs = self.crossattention(
458
+ query_attention_output,
459
+ attention_mask,
460
+ head_mask,
461
+ encoder_hidden_states,
462
+ encoder_attention_mask,
463
+ output_attentions=output_attentions,
464
+ )
465
+ query_attention_output = cross_attention_outputs[0]
466
+ outputs = (
467
+ outputs + cross_attention_outputs[1:-1]
468
+ ) # add cross attentions if we output attention weights
469
+
470
+ layer_output = apply_chunking_to_forward(
471
+ self.feed_forward_chunk_query,
472
+ self.chunk_size_feed_forward,
473
+ self.seq_len_dim,
474
+ query_attention_output,
475
+ )
476
+ if attention_output.shape[1] > query_length:
477
+ layer_output_text = apply_chunking_to_forward(
478
+ self.feed_forward_chunk,
479
+ self.chunk_size_feed_forward,
480
+ self.seq_len_dim,
481
+ attention_output[:, query_length:, :],
482
+ )
483
+ layer_output = torch.cat([layer_output, layer_output_text], dim=1)
484
+ else:
485
+ layer_output = apply_chunking_to_forward(
486
+ self.feed_forward_chunk,
487
+ self.chunk_size_feed_forward,
488
+ self.seq_len_dim,
489
+ attention_output,
490
+ )
491
+ outputs = (layer_output,) + outputs
492
+
493
+ outputs = outputs + (present_key_value,)
494
+
495
+ return outputs
496
+
497
+ def feed_forward_chunk(self, attention_output):
498
+ intermediate_output = self.intermediate(attention_output)
499
+ layer_output = self.output(intermediate_output, attention_output)
500
+ return layer_output
501
+
502
+ def feed_forward_chunk_query(self, attention_output):
503
+ intermediate_output = self.intermediate_query(attention_output)
504
+ layer_output = self.output_query(intermediate_output, attention_output)
505
+ return layer_output
506
+
507
+
508
+ class BertEncoder(nn.Module):
509
+ def __init__(self, config):
510
+ super().__init__()
511
+ self.config = config
512
+ self.layer = nn.ModuleList(
513
+ [BertLayer(config, i) for i in range(config.num_hidden_layers)]
514
+ )
515
+
516
+ def forward(
517
+ self,
518
+ hidden_states,
519
+ attention_mask=None,
520
+ head_mask=None,
521
+ encoder_hidden_states=None,
522
+ encoder_attention_mask=None,
523
+ past_key_values=None,
524
+ use_cache=None,
525
+ output_attentions=False,
526
+ output_hidden_states=False,
527
+ return_dict=True,
528
+ query_length=0,
529
+ ):
530
+ all_hidden_states = () if output_hidden_states else None
531
+ all_self_attentions = () if output_attentions else None
532
+ all_cross_attentions = (
533
+ () if output_attentions and self.config.add_cross_attention else None
534
+ )
535
+
536
+ next_decoder_cache = () if use_cache else None
537
+
538
+ for i in range(self.config.num_hidden_layers):
539
+ layer_module = self.layer[i]
540
+ if output_hidden_states:
541
+ all_hidden_states = all_hidden_states + (hidden_states,)
542
+
543
+ layer_head_mask = head_mask[i] if head_mask is not None else None
544
+ past_key_value = past_key_values[i] if past_key_values is not None else None
545
+
546
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
547
+
548
+ if use_cache:
549
+ logger.warn(
550
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
551
+ )
552
+ use_cache = False
553
+
554
+ def create_custom_forward(module):
555
+ def custom_forward(*inputs):
556
+ return module(
557
+ *inputs, past_key_value, output_attentions, query_length
558
+ )
559
+
560
+ return custom_forward
561
+
562
+ layer_outputs = torch.utils.checkpoint.checkpoint(
563
+ create_custom_forward(layer_module),
564
+ hidden_states,
565
+ attention_mask,
566
+ layer_head_mask,
567
+ encoder_hidden_states,
568
+ encoder_attention_mask,
569
+ )
570
+ else:
571
+ layer_outputs = layer_module(
572
+ hidden_states,
573
+ attention_mask,
574
+ layer_head_mask,
575
+ encoder_hidden_states,
576
+ encoder_attention_mask,
577
+ past_key_value,
578
+ output_attentions,
579
+ query_length,
580
+ )
581
+
582
+ hidden_states = layer_outputs[0]
583
+ if use_cache:
584
+ next_decoder_cache += (layer_outputs[-1],)
585
+ if output_attentions:
586
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
587
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
588
+
589
+ if output_hidden_states:
590
+ all_hidden_states = all_hidden_states + (hidden_states,)
591
+
592
+ if not return_dict:
593
+ return tuple(
594
+ v
595
+ for v in [
596
+ hidden_states,
597
+ next_decoder_cache,
598
+ all_hidden_states,
599
+ all_self_attentions,
600
+ all_cross_attentions,
601
+ ]
602
+ if v is not None
603
+ )
604
+ return BaseModelOutputWithPastAndCrossAttentions(
605
+ last_hidden_state=hidden_states,
606
+ past_key_values=next_decoder_cache,
607
+ hidden_states=all_hidden_states,
608
+ attentions=all_self_attentions,
609
+ cross_attentions=all_cross_attentions,
610
+ )
611
+
612
+
613
+ class BertPooler(nn.Module):
614
+ def __init__(self, config):
615
+ super().__init__()
616
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
617
+ self.activation = nn.Tanh()
618
+
619
+ def forward(self, hidden_states):
620
+ # We "pool" the model by simply taking the hidden state corresponding
621
+ # to the first token.
622
+ first_token_tensor = hidden_states[:, 0]
623
+ pooled_output = self.dense(first_token_tensor)
624
+ pooled_output = self.activation(pooled_output)
625
+ return pooled_output
626
+
627
+
628
+ class BertPredictionHeadTransform(nn.Module):
629
+ def __init__(self, config):
630
+ super().__init__()
631
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
632
+ if isinstance(config.hidden_act, str):
633
+ self.transform_act_fn = ACT2FN[config.hidden_act]
634
+ else:
635
+ self.transform_act_fn = config.hidden_act
636
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
637
+
638
+ def forward(self, hidden_states):
639
+ hidden_states = self.dense(hidden_states)
640
+ hidden_states = self.transform_act_fn(hidden_states)
641
+ hidden_states = self.LayerNorm(hidden_states)
642
+ return hidden_states
643
+
644
+
645
+ class BertLMPredictionHead(nn.Module):
646
+ def __init__(self, config):
647
+ super().__init__()
648
+ self.transform = BertPredictionHeadTransform(config)
649
+
650
+ # The output weights are the same as the input embeddings, but there is
651
+ # an output-only bias for each token.
652
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
653
+
654
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
655
+
656
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
657
+ self.decoder.bias = self.bias
658
+
659
+ def forward(self, hidden_states):
660
+ hidden_states = self.transform(hidden_states)
661
+ hidden_states = self.decoder(hidden_states)
662
+ return hidden_states
663
+
664
+
665
+ class BertOnlyMLMHead(nn.Module):
666
+ def __init__(self, config):
667
+ super().__init__()
668
+ self.predictions = BertLMPredictionHead(config)
669
+
670
+ def forward(self, sequence_output):
671
+ prediction_scores = self.predictions(sequence_output)
672
+ return prediction_scores
673
+
674
+
675
+ class BertPreTrainedModel(PreTrainedModel):
676
+ """
677
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
678
+ models.
679
+ """
680
+
681
+ config_class = BertConfig
682
+ base_model_prefix = "bert"
683
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
684
+
685
+ def _init_weights(self, module):
686
+ """Initialize the weights"""
687
+ if isinstance(module, (nn.Linear, nn.Embedding)):
688
+ # Slightly different from the TF version which uses truncated_normal for initialization
689
+ # cf https://github.com/pytorch/pytorch/pull/5617
690
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
691
+ elif isinstance(module, nn.LayerNorm):
692
+ module.bias.data.zero_()
693
+ module.weight.data.fill_(1.0)
694
+ if isinstance(module, nn.Linear) and module.bias is not None:
695
+ module.bias.data.zero_()
696
+
697
+
698
+ class BertModel(BertPreTrainedModel):
699
+ """
700
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
701
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
702
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
703
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
704
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
705
+ input to the forward pass.
706
+ """
707
+
708
+ def __init__(self, config, add_pooling_layer=False):
709
+ super().__init__(config)
710
+ self.config = config
711
+
712
+ self.embeddings = BertEmbeddings(config)
713
+
714
+ self.encoder = BertEncoder(config)
715
+
716
+ self.pooler = BertPooler(config) if add_pooling_layer else None
717
+
718
+ self.init_weights()
719
+
720
+ def get_input_embeddings(self):
721
+ return self.embeddings.word_embeddings
722
+
723
+ def set_input_embeddings(self, value):
724
+ self.embeddings.word_embeddings = value
725
+
726
+ def _prune_heads(self, heads_to_prune):
727
+ """
728
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
729
+ class PreTrainedModel
730
+ """
731
+ for layer, heads in heads_to_prune.items():
732
+ self.encoder.layer[layer].attention.prune_heads(heads)
733
+
734
+ def get_extended_attention_mask(
735
+ self,
736
+ attention_mask: Tensor,
737
+ input_shape: Tuple[int],
738
+ device: device,
739
+ is_decoder: bool,
740
+ has_query: bool = False,
741
+ ) -> Tensor:
742
+ """
743
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
744
+
745
+ Arguments:
746
+ attention_mask (:obj:`torch.Tensor`):
747
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
748
+ input_shape (:obj:`Tuple[int]`):
749
+ The shape of the input to the model.
750
+ device: (:obj:`torch.device`):
751
+ The device of the input to the model.
752
+
753
+ Returns:
754
+ :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
755
+ """
756
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
757
+ # ourselves in which case we just need to make it broadcastable to all heads.
758
+ if attention_mask.dim() == 3:
759
+ extended_attention_mask = attention_mask[:, None, :, :]
760
+ elif attention_mask.dim() == 2:
761
+ # Provided a padding mask of dimensions [batch_size, seq_length]
762
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
763
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
764
+ if is_decoder:
765
+ batch_size, seq_length = input_shape
766
+
767
+ seq_ids = torch.arange(seq_length, device=device)
768
+ causal_mask = (
769
+ seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
770
+ <= seq_ids[None, :, None]
771
+ )
772
+
773
+ # add a prefix ones mask to the causal mask
774
+ # causal and attention masks must have same type with pytorch version < 1.3
775
+ causal_mask = causal_mask.to(attention_mask.dtype)
776
+
777
+ if causal_mask.shape[1] < attention_mask.shape[1]:
778
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
779
+ if has_query: # UniLM style attention mask
780
+ causal_mask = torch.cat(
781
+ [
782
+ torch.zeros(
783
+ (batch_size, prefix_seq_len, seq_length),
784
+ device=device,
785
+ dtype=causal_mask.dtype,
786
+ ),
787
+ causal_mask,
788
+ ],
789
+ axis=1,
790
+ )
791
+ causal_mask = torch.cat(
792
+ [
793
+ torch.ones(
794
+ (batch_size, causal_mask.shape[1], prefix_seq_len),
795
+ device=device,
796
+ dtype=causal_mask.dtype,
797
+ ),
798
+ causal_mask,
799
+ ],
800
+ axis=-1,
801
+ )
802
+ extended_attention_mask = (
803
+ causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
804
+ )
805
+ else:
806
+ extended_attention_mask = attention_mask[:, None, None, :]
807
+ else:
808
+ raise ValueError(
809
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
810
+ input_shape, attention_mask.shape
811
+ )
812
+ )
813
+
814
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
815
+ # masked positions, this operation will create a tensor which is 0.0 for
816
+ # positions we want to attend and -10000.0 for masked positions.
817
+ # Since we are adding it to the raw scores before the softmax, this is
818
+ # effectively the same as removing these entirely.
819
+ extended_attention_mask = extended_attention_mask.to(
820
+ dtype=self.dtype
821
+ ) # fp16 compatibility
822
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
823
+ return extended_attention_mask
824
+
825
+ def forward(
826
+ self,
827
+ input_ids=None,
828
+ attention_mask=None,
829
+ position_ids=None,
830
+ head_mask=None,
831
+ query_embeds=None,
832
+ encoder_hidden_states=None,
833
+ encoder_attention_mask=None,
834
+ past_key_values=None,
835
+ use_cache=None,
836
+ output_attentions=None,
837
+ output_hidden_states=None,
838
+ return_dict=None,
839
+ is_decoder=False,
840
+ ):
841
+ r"""
842
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
843
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
844
+ the model is configured as a decoder.
845
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
846
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
847
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
848
+ - 1 for tokens that are **not masked**,
849
+ - 0 for tokens that are **masked**.
850
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
851
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
852
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
853
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
854
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
855
+ use_cache (:obj:`bool`, `optional`):
856
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
857
+ decoding (see :obj:`past_key_values`).
858
+ """
859
+ output_attentions = (
860
+ output_attentions
861
+ if output_attentions is not None
862
+ else self.config.output_attentions
863
+ )
864
+ output_hidden_states = (
865
+ output_hidden_states
866
+ if output_hidden_states is not None
867
+ else self.config.output_hidden_states
868
+ )
869
+ return_dict = (
870
+ return_dict if return_dict is not None else self.config.use_return_dict
871
+ )
872
+
873
+ # use_cache = use_cache if use_cache is not None else self.config.use_cache
874
+
875
+ if input_ids is None:
876
+ assert (
877
+ query_embeds is not None
878
+ ), "You have to specify query_embeds when input_ids is None"
879
+
880
+ # past_key_values_length
881
+ past_key_values_length = (
882
+ past_key_values[0][0].shape[2] - self.config.query_length
883
+ if past_key_values is not None
884
+ else 0
885
+ )
886
+
887
+ query_length = query_embeds.shape[1] if query_embeds is not None else 0
888
+
889
+ embedding_output = self.embeddings(
890
+ input_ids=input_ids,
891
+ position_ids=position_ids,
892
+ query_embeds=query_embeds,
893
+ past_key_values_length=past_key_values_length,
894
+ )
895
+
896
+ input_shape = embedding_output.size()[:-1]
897
+ batch_size, seq_length = input_shape
898
+ device = embedding_output.device
899
+
900
+ if attention_mask is None:
901
+ attention_mask = torch.ones(
902
+ ((batch_size, seq_length + past_key_values_length)), device=device
903
+ )
904
+
905
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
906
+ # ourselves in which case we just need to make it broadcastable to all heads.
907
+ if is_decoder:
908
+ extended_attention_mask = self.get_extended_attention_mask(
909
+ attention_mask,
910
+ input_ids.shape,
911
+ device,
912
+ is_decoder,
913
+ has_query=(query_embeds is not None),
914
+ )
915
+ else:
916
+ extended_attention_mask = self.get_extended_attention_mask(
917
+ attention_mask, input_shape, device, is_decoder
918
+ )
919
+
920
+ # If a 2D or 3D attention mask is provided for the cross-attention
921
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
922
+ if encoder_hidden_states is not None:
923
+ if type(encoder_hidden_states) == list:
924
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
925
+ 0
926
+ ].size()
927
+ else:
928
+ (
929
+ encoder_batch_size,
930
+ encoder_sequence_length,
931
+ _,
932
+ ) = encoder_hidden_states.size()
933
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
934
+
935
+ if type(encoder_attention_mask) == list:
936
+ encoder_extended_attention_mask = [
937
+ self.invert_attention_mask(mask) for mask in encoder_attention_mask
938
+ ]
939
+ elif encoder_attention_mask is None:
940
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
941
+ encoder_extended_attention_mask = self.invert_attention_mask(
942
+ encoder_attention_mask
943
+ )
944
+ else:
945
+ encoder_extended_attention_mask = self.invert_attention_mask(
946
+ encoder_attention_mask
947
+ )
948
+ else:
949
+ encoder_extended_attention_mask = None
950
+
951
+ # Prepare head mask if needed
952
+ # 1.0 in head_mask indicate we keep the head
953
+ # attention_probs has shape bsz x n_heads x N x N
954
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
955
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
956
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
957
+
958
+ encoder_outputs = self.encoder(
959
+ embedding_output,
960
+ attention_mask=extended_attention_mask,
961
+ head_mask=head_mask,
962
+ encoder_hidden_states=encoder_hidden_states,
963
+ encoder_attention_mask=encoder_extended_attention_mask,
964
+ past_key_values=past_key_values,
965
+ use_cache=use_cache,
966
+ output_attentions=output_attentions,
967
+ output_hidden_states=output_hidden_states,
968
+ return_dict=return_dict,
969
+ query_length=query_length,
970
+ )
971
+ sequence_output = encoder_outputs[0]
972
+ pooled_output = (
973
+ self.pooler(sequence_output) if self.pooler is not None else None
974
+ )
975
+
976
+ if not return_dict:
977
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
978
+
979
+ return BaseModelOutputWithPoolingAndCrossAttentions(
980
+ last_hidden_state=sequence_output,
981
+ pooler_output=pooled_output,
982
+ past_key_values=encoder_outputs.past_key_values,
983
+ hidden_states=encoder_outputs.hidden_states,
984
+ attentions=encoder_outputs.attentions,
985
+ cross_attentions=encoder_outputs.cross_attentions,
986
+ )
987
+
988
+
989
+ class BertLMHeadModel(BertPreTrainedModel):
990
+
991
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
992
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
993
+
994
+ def __init__(self, config):
995
+ super().__init__(config)
996
+
997
+ self.bert = BertModel(config, add_pooling_layer=False)
998
+ self.cls = BertOnlyMLMHead(config)
999
+
1000
+ self.init_weights()
1001
+
1002
+ def get_output_embeddings(self):
1003
+ return self.cls.predictions.decoder
1004
+
1005
+ def set_output_embeddings(self, new_embeddings):
1006
+ self.cls.predictions.decoder = new_embeddings
1007
+
1008
+ def forward(
1009
+ self,
1010
+ input_ids=None,
1011
+ attention_mask=None,
1012
+ position_ids=None,
1013
+ head_mask=None,
1014
+ query_embeds=None,
1015
+ encoder_hidden_states=None,
1016
+ encoder_attention_mask=None,
1017
+ labels=None,
1018
+ past_key_values=None,
1019
+ use_cache=True,
1020
+ output_attentions=None,
1021
+ output_hidden_states=None,
1022
+ return_dict=None,
1023
+ return_logits=False,
1024
+ is_decoder=True,
1025
+ reduction="mean",
1026
+ ):
1027
+ r"""
1028
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
1029
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1030
+ the model is configured as a decoder.
1031
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1032
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1033
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
1034
+ - 1 for tokens that are **not masked**,
1035
+ - 0 for tokens that are **masked**.
1036
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1037
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1038
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
1039
+ ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1040
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1041
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1042
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
1043
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
1044
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
1045
+ use_cache (:obj:`bool`, `optional`):
1046
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
1047
+ decoding (see :obj:`past_key_values`).
1048
+ Returns:
1049
+ Example::
1050
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
1051
+ >>> import torch
1052
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
1053
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
1054
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
1055
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1056
+ >>> outputs = model(**inputs)
1057
+ >>> prediction_logits = outputs.logits
1058
+ """
1059
+ return_dict = (
1060
+ return_dict if return_dict is not None else self.config.use_return_dict
1061
+ )
1062
+ if labels is not None:
1063
+ use_cache = False
1064
+ if past_key_values is not None:
1065
+ query_embeds = None
1066
+
1067
+ outputs = self.bert(
1068
+ input_ids,
1069
+ attention_mask=attention_mask,
1070
+ position_ids=position_ids,
1071
+ head_mask=head_mask,
1072
+ query_embeds=query_embeds,
1073
+ encoder_hidden_states=encoder_hidden_states,
1074
+ encoder_attention_mask=encoder_attention_mask,
1075
+ past_key_values=past_key_values,
1076
+ use_cache=use_cache,
1077
+ output_attentions=output_attentions,
1078
+ output_hidden_states=output_hidden_states,
1079
+ return_dict=return_dict,
1080
+ is_decoder=is_decoder,
1081
+ )
1082
+
1083
+ sequence_output = outputs[0]
1084
+ if query_embeds is not None:
1085
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1086
+
1087
+ prediction_scores = self.cls(sequence_output)
1088
+
1089
+ if return_logits:
1090
+ return prediction_scores[:, :-1, :].contiguous()
1091
+
1092
+ lm_loss = None
1093
+ if labels is not None:
1094
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1095
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1096
+ labels = labels[:, 1:].contiguous()
1097
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
1098
+ lm_loss = loss_fct(
1099
+ shifted_prediction_scores.view(-1, self.config.vocab_size),
1100
+ labels.view(-1),
1101
+ )
1102
+ if reduction == "none":
1103
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
1104
+
1105
+ if not return_dict:
1106
+ output = (prediction_scores,) + outputs[2:]
1107
+ return ((lm_loss,) + output) if lm_loss is not None else output
1108
+
1109
+ return CausalLMOutputWithCrossAttentions(
1110
+ loss=lm_loss,
1111
+ logits=prediction_scores,
1112
+ past_key_values=outputs.past_key_values,
1113
+ hidden_states=outputs.hidden_states,
1114
+ attentions=outputs.attentions,
1115
+ cross_attentions=outputs.cross_attentions,
1116
+ )
1117
+
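The loss above is the usual causal-LM shift: the score at position t is compared against the token at position t+1. A standalone sketch with a toy vocabulary (the 0.1 label smoothing mirrors the value used above):

import torch
from torch.nn import CrossEntropyLoss

vocab_size, seq_len = 10, 5
prediction_scores = torch.randn(1, seq_len, vocab_size)    # [batch, seq, vocab]
labels = torch.randint(0, vocab_size, (1, seq_len))

# drop the last prediction and the first label so position t predicts token t+1
shifted_scores = prediction_scores[:, :-1, :].contiguous()
shifted_labels = labels[:, 1:].contiguous()

loss_fct = CrossEntropyLoss(label_smoothing=0.1)
lm_loss = loss_fct(shifted_scores.view(-1, vocab_size), shifted_labels.view(-1))
print(lm_loss)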
1118
+ def prepare_inputs_for_generation(
1119
+ self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
1120
+ ):
1121
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1122
+ if attention_mask is None:
1123
+ attention_mask = input_ids.new_ones(input_ids.shape)
1124
+ query_mask = input_ids.new_ones(query_embeds.shape[:-1])
1125
+ attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
1126
+
1127
+ # cut decoder_input_ids if past is used
1128
+ if past is not None:
1129
+ input_ids = input_ids[:, -1:]
1130
+
1131
+ return {
1132
+ "input_ids": input_ids,
1133
+ "query_embeds": query_embeds,
1134
+ "attention_mask": attention_mask,
1135
+ "past_key_values": past,
1136
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
1137
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
1138
+ "is_decoder": True,
1139
+ }
1140
+
1141
+ def _reorder_cache(self, past, beam_idx):
1142
+ reordered_past = ()
1143
+ for layer_past in past:
1144
+ reordered_past += (
1145
+ tuple(
1146
+ past_state.index_select(0, beam_idx) for past_state in layer_past
1147
+ ),
1148
+ )
1149
+ return reordered_past
1150
+
1151
+
1152
+ class BertForMaskedLM(BertPreTrainedModel):
1153
+
1154
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1155
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1156
+
1157
+ def __init__(self, config):
1158
+ super().__init__(config)
1159
+
1160
+ self.bert = BertModel(config, add_pooling_layer=False)
1161
+ self.cls = BertOnlyMLMHead(config)
1162
+
1163
+ self.init_weights()
1164
+
1165
+ def get_output_embeddings(self):
1166
+ return self.cls.predictions.decoder
1167
+
1168
+ def set_output_embeddings(self, new_embeddings):
1169
+ self.cls.predictions.decoder = new_embeddings
1170
+
1171
+ def forward(
1172
+ self,
1173
+ input_ids=None,
1174
+ attention_mask=None,
1175
+ position_ids=None,
1176
+ head_mask=None,
1177
+ query_embeds=None,
1178
+ encoder_hidden_states=None,
1179
+ encoder_attention_mask=None,
1180
+ labels=None,
1181
+ output_attentions=None,
1182
+ output_hidden_states=None,
1183
+ return_dict=None,
1184
+ return_logits=False,
1185
+ is_decoder=False,
1186
+ ):
1187
+ r"""
1188
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1189
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1190
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1191
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1192
+ """
1193
+
1194
+ return_dict = (
1195
+ return_dict if return_dict is not None else self.config.use_return_dict
1196
+ )
1197
+
1198
+ outputs = self.bert(
1199
+ input_ids,
1200
+ attention_mask=attention_mask,
1201
+ position_ids=position_ids,
1202
+ head_mask=head_mask,
1203
+ query_embeds=query_embeds,
1204
+ encoder_hidden_states=encoder_hidden_states,
1205
+ encoder_attention_mask=encoder_attention_mask,
1206
+ output_attentions=output_attentions,
1207
+ output_hidden_states=output_hidden_states,
1208
+ return_dict=return_dict,
1209
+ is_decoder=is_decoder,
1210
+ )
1211
+
1212
+ if query_embeds is not None:
1213
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1214
+ prediction_scores = self.cls(sequence_output)
1215
+
1216
+ if return_logits:
1217
+ return prediction_scores
1218
+
1219
+ masked_lm_loss = None
1220
+ if labels is not None:
1221
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1222
+ masked_lm_loss = loss_fct(
1223
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
1224
+ )
1225
+
1226
+ if not return_dict:
1227
+ output = (prediction_scores,) + outputs[2:]
1228
+ return (
1229
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1230
+ )
1231
+
1232
+ return MaskedLMOutput(
1233
+ loss=masked_lm_loss,
1234
+ logits=prediction_scores,
1235
+ hidden_states=outputs.hidden_states,
1236
+ attentions=outputs.attentions,
1237
+ )
1238
+
1239
+
1240
+ def build_qformer(num_query_token, vision_width,
1241
+ qformer_hidden_dropout_prob=0.1,
1242
+ qformer_attention_probs_dropout_prob=0.1,
1243
+ qformer_drop_path_rate=0.,
1244
+ bert_type="bert-base-uncased"
1245
+ ):
1246
+
1247
+ encoder_config = BertConfig.from_pretrained(bert_type)
1248
+ encoder_config.encoder_width = vision_width
1249
+ # insert cross-attention layer every other block
1250
+ encoder_config.add_cross_attention = True
1251
+ encoder_config.cross_attention_freq = 2
1252
+ encoder_config.query_length = num_query_token
1253
+ encoder_config.hidden_dropout_prob = qformer_hidden_dropout_prob
1254
+ encoder_config.attention_probs_dropout_prob = qformer_attention_probs_dropout_prob
1255
+ encoder_config.drop_path_list = [x.item() for x in torch.linspace(0, qformer_drop_path_rate, encoder_config.num_hidden_layers)]
1256
+ logger.info(f"Drop_path:{encoder_config.drop_path_list}")
1257
+ logger.info(encoder_config)
1258
+ Qformer = BertLMHeadModel(encoder_config)
1259
+ query_tokens = nn.Parameter(
1260
+ torch.zeros(1, num_query_token, encoder_config.hidden_size)
1261
+ )
1262
+ query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
1263
+ return Qformer, query_tokens
1264
+
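A minimal usage sketch of build_qformer as defined above; the 32 query tokens and vision width of 1408 are illustrative values only, and BertConfig.from_pretrained requires bert-base-uncased to be downloadable or cached:

import torch

# hypothetical example values; the real sizes come from the model's config.json
Qformer, query_tokens = build_qformer(
    num_query_token=32,
    vision_width=1408,
    qformer_hidden_dropout_prob=0.1,
    qformer_attention_probs_dropout_prob=0.1,
)

fake_vision_feats = torch.randn(2, 257, 1408)               # [batch, patches, vision_width]
vision_atts = torch.ones(fake_vision_feats.shape[:-1], dtype=torch.long)
out = Qformer.bert(
    query_embeds=query_tokens.expand(2, -1, -1),            # learned queries, randomly initialized here
    encoder_hidden_states=fake_vision_feats,
    encoder_attention_mask=vision_atts,
    return_dict=True,
)
print(out.last_hidden_state.shape)                          # [2, 32, hidden_size]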
modeling_special_token.py ADDED
@@ -0,0 +1,27 @@
1
+ import transformers
2
+ DEFAULT_IMG_TOKEN = "[IMG]"
3
+ DEFAULT_IMG_END_TOKEN = "[/IMG]"
4
+
5
+ DEFAULT_IMAGE_TOKEN = "<image>"
6
+ DEFAULT_VIDEO_TOKEN = "[VIDEO]"
7
+
8
+ IMG_TOKEN = "[<IMG_PLH>]"
9
+ VID_TOKEN = "[<VID_PLH>]"
10
+
11
+ BOX_START = '<box_begin>'
12
+ ATBOXES_PLACEHOLDER = '<box_begin><boxes>'
13
+ BOXES_PLACEHOLDER = '<boxes>'
14
+ EXPR_PLACEHOLDER = '<expr>'
15
+ QUESTION_PLACEHOLDER = '<question>'
16
+ TIME_START = '<time_begin>'
17
+ TIME_PLACEHOLDER = '<temp>'
18
+ ATTEMP_PLACEHOLDER = TIME_START + TIME_PLACEHOLDER
19
+ TRACK_START='<track_begin>'
20
+ TRACK_PLACEHOLDER = '<tracking>'
21
+ TRACK_START_BOX = '<track_box>'
22
+ ATTRACK_PLACEHOLDER = TRACK_START + TRACK_PLACEHOLDER
23
+ need_template_list = ['REC', 'flickr', 'tracking', 'tracking2', 'tracking3', 'tracking4']
24
+
25
+ load_image_list = ['image', 'REC', 'flickr']
26
+ load_video_list = ['video', 'TVG', 'tracking', 'tracking2','tracking3', 'tracking4', 'TVG+HL']
27
+ special_tokens = [BOX_START, TIME_START, TIME_PLACEHOLDER, BOXES_PLACEHOLDER, TRACK_START, TRACK_PLACEHOLDER, TRACK_START_BOX]
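These placeholders only take effect once they are added to the tokenizer and the LM embedding table is resized, which is what added_tokens.json records (ids 32000-32006). A hedged sketch with the standard transformers API; the base checkpoint name below is a placeholder:

from transformers import AutoTokenizer, AutoModelForCausalLM

base = "lmsys/vicuna-7b-v1.5"    # placeholder; the actual base LLM is declared in config.json
tokenizer = AutoTokenizer.from_pretrained(base)
model = AutoModelForCausalLM.from_pretrained(base)

# register the grounding / tracking placeholders defined above
num_added = tokenizer.add_tokens(special_tokens, special_tokens=True)
model.resize_token_embeddings(len(tokenizer))
print(num_added, tokenizer.convert_tokens_to_ids(BOX_START))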
modeling_videochate.py ADDED
@@ -0,0 +1,681 @@
1
+ import io
2
+ import logging
3
+ import json
4
+ import torch
5
+ import torch.utils.checkpoint
6
+ from torch import nn
7
+ from torch.nn import MSELoss
8
+ from transformers.modeling_outputs import (
9
+ CausalLMOutputWithPast,
10
+ )
11
+ from typing import List, Optional, Tuple, Union
12
+ from transformers import LlamaForCausalLM
13
+ from transformers.modeling_outputs import (
14
+ CausalLMOutputWithPast,
15
+ )
16
+
17
+ from torch.cuda.amp import autocast as autocast
18
+ import torch.nn.functional as F
19
+
20
+ import numpy as np
21
+ from .modeling_vit import build_vit, MLP, PostProcess
22
+
23
+ from .modeling_qformer import build_qformer
24
+ from .modeling_base import BaseMLLM
25
+
26
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
27
+ logger = logging.getLogger(__name__)
28
+
29
+ import pycocotools.mask as mask_util
30
+
31
+ from .modeling_base import VID_TOKEN, IMG_TOKEN
32
+
33
+ class MultiModalLLM_PT(BaseMLLM):
34
+ def __init__(
35
+ self,
36
+ config,
37
+ _tokenizer=None
38
+ ):
39
+ super().__init__(config=config, _tokenizer=_tokenizer)
40
+ self.use_clip = False
41
+ self.num_frames = 16
42
+ self.num_clips = 1
43
+ self.token_merge_len = 4
44
+
45
+ self.per_clip_frames = self.num_frames // self.num_clips
46
+
47
+ print(self.config)
48
+ self.merge_proj = nn.Linear(
49
+ self.qformer.config.hidden_size*self.token_merge_len, self.config.hidden_size
50
+ )
51
+
52
+ if config.build_decoder:
53
+ self.track_embed = MLP(self.config.hidden_size, self.config.hidden_size, 3 * 256, 2, dropout=0)
54
+ self.track_embed_decode2 = MLP(4096, 4096, 4, 2, dropout=0)
55
+ self.temporal_embed = MLP(self.config.hidden_size, self.config.hidden_size, 2, 2, dropout=0.3)
56
+ self.action_embed = MLP(self.config.hidden_size, self.config.hidden_size, 1, 2, dropout=0.3)
57
+ self.postprocess = PostProcess()
58
+ self.track_token = nn.Parameter(torch.randn((1, 1, 4096)))
59
+ self.temporal_token = nn.Parameter(torch.randn((1, 1, 4096)))
60
+ self.box_token = nn.Parameter(torch.randn((1, 1, 4096)))
61
+
62
+
63
+ def forward(
64
+ self,
65
+ input_ids: torch.LongTensor = None,
66
+ attention_mask: Optional[torch.Tensor] = None,
67
+ labels: Optional[torch.LongTensor] = None,
68
+ image: Optional[torch.Tensor] = None,
69
+ video: Optional[torch.Tensor] = None,
70
+ instruction = None,
71
+ video_idx = None,
72
+ image_idx = None,
73
+ output_boxes = None, # REC
74
+ input_boxes = None, # tracking inputs
75
+ text_input = None,
76
+ video_info = None,
77
+ temporal_labels = None,
78
+ gt_masks = None,
79
+ sam_images = None,
80
+ size_hw = None,
81
+ path = None,
82
+ mask_path = None,
83
+ tvg_inputs = None,
84
+ tvg_targets = None,
85
+ ):
86
+ if text_input is not None:
87
+ time_instructions = self.get_clip_time_instruct(text_input)
88
+ else:
89
+ time_instructions = None
90
+ text_embeds = self.pad_text_embeds(input_ids=input_ids, image=image, video=video, return_visual=False,
91
+ video_idx=video_idx, image_idx=image_idx, instruction = instruction,
92
+ output_boxes = output_boxes, input_boxes=input_boxes, time_instructions = time_instructions)
93
+ outputs = self.lm(
94
+ inputs_embeds=text_embeds,
95
+ attention_mask=attention_mask,
96
+ labels=labels,
97
+ output_hidden_states=True,
98
+ return_dict=True,
99
+ )
100
+ loss = outputs.loss
101
+ logger.info(f'llm loss:{loss}')
102
+
103
+ if output_boxes is not None and self.use_bbox_loss:
104
+ last_hidden_states = outputs.hidden_states[-1]
105
+ pred_locs = []
106
+ for idx in range(last_hidden_states.shape[0]):
107
+ loc_positions = ( (input_ids[idx].flatten() == self.tokenizer.box_token) ).nonzero().flatten()
108
+ selected_hidden_states = last_hidden_states[idx][loc_positions]
109
+ pred_locs.append(self.loc_decoder(selected_hidden_states))
110
+ box_loss = self.box_loss(pred_locs, output_boxes)
111
+ logger.info(f'box loss:{box_loss}')
112
+ loss += box_loss
113
+
114
+ if (gt_masks is not None or input_boxes is not None) and self.use_mask_loss:
115
+ last_hidden_states = outputs.hidden_states[-1]
116
+ pred_masks = []
117
+ sam_losses = []
118
+ box_losses = []
119
+ for idx in range(last_hidden_states.shape[0]):
120
+ loc_positions = ( (input_ids[idx].flatten() == self.tokenizer.track_token) ).nonzero().flatten()
121
+ selected_hidden_states = last_hidden_states[idx][loc_positions]
122
+ embed_sam_boxes = self.track_embed(selected_hidden_states).reshape(1, 3, 256)
123
+ inference_state = self.sam.init_state_images(sam_images, size_hw[idx][0], size_hw[idx][1])
124
+
125
+ if input_boxes is not None:
126
+ gt_embeds = self.sam.get_prompt_embeding(inference_state, None, None, False, input_boxes[idx], device = text_embeds.device)
127
+ else:
128
+ input_boxes = self.find_boundaries_torch(gt_masks.squeeze(0)[:,:,:1].squeeze(2).cpu()).to(text_embeds.device)
129
+ gt_embeds = self.sam.get_prompt_embeding(inference_state, None, None, False, input_boxes, device = text_embeds.device)
130
+ pred_locs = [self.track_embed_decode2((selected_hidden_states))[0]]
131
+ target_boxes = [input_boxes[idx]]
132
+
133
+ src_boxes = pred_locs
134
+ loss_bbox = self.box_loss2(src_boxes, target_boxes)
135
+
136
+ loss_bbox = self.masked_loss(loss_bbox, 0)
137
+ box_losses.append(loss_bbox)
138
+ sam_losses.append( F.l1_loss(embed_sam_boxes, gt_embeds))
139
+
140
+ logger.info(f'referring sam loss:{sam_losses}')
141
+ sam_losses = torch.stack(sam_losses)
142
+ box_losses = torch.stack(box_losses)
143
+ loss += torch.mean(sam_losses)
144
+ loss += torch.mean(box_losses)
145
+
146
+ if tvg_inputs is not None and self.use_temporal_loss:
147
+ last_hidden_states = outputs.hidden_states[-1] # [bsz,1024, 4096]
148
+ last_hidden_states = last_hidden_states.view(-1, last_hidden_states.size(-1)) # [bsz*1024, 4096]
149
+ loc_positions = (input_ids.flatten()==self.tokenizer.temp_token).nonzero().flatten() # [bsz]
150
+ prompt_token = last_hidden_states[loc_positions]
151
+ prompt_token = prompt_token.view(input_ids.shape[0], -1 ,prompt_token.shape[-1]) # [bsz, 1, 4096]
152
+
153
+
154
+ cg_outputs = self.cg_model(**tvg_inputs, targets=tvg_targets, prompt_token=prompt_token)
155
+ loss_dict = self.cg_criterion(cg_outputs, tvg_targets)
156
+ weight_dict = self.cg_criterion.weight_dict
157
+ tvg_loss = 0.05*sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
158
+ logger.info(f'tvg_loss:{tvg_loss}')
159
+ loss += tvg_loss
160
+
161
+
162
+ logger.info(f'all loss:{loss}')
163
+ return CausalLMOutputWithPast(
164
+ loss=loss,
165
+ logits=outputs.logits,
166
+ past_key_values=outputs.past_key_values,
167
+ hidden_states=outputs.hidden_states,
168
+ attentions=outputs.attentions,
169
+ )
170
+
171
+ def pad_text_embeds(
172
+ self,
173
+ input_ids: torch.LongTensor = None,
174
+ image: Optional[torch.Tensor] = None,
175
+ video: Optional[torch.Tensor] = None,
176
+ image_idx = None,
177
+ video_idx = None,
178
+ return_visual: bool = False,
179
+ instruction = None,
180
+ output_boxes = None, # boxes for REC
181
+ input_boxes = None, # boxes for tracking
182
+ time_instructions = None,
183
+ ):
184
+ text_embeds = self.lm.get_input_embeddings()(input_ids.long()).detach()
185
+ if input_boxes is not None:
186
+ input_boxes = input_boxes[0].to(dtype=text_embeds.dtype)
187
+
188
+ boxes_emb = self.loc_encoder(input_boxes)
189
+ boxes_emb = boxes_emb.view(-1, 4096)
190
+
191
+ text_embeds[input_ids == torch.full_like(input_ids, self.tokenizer.track_box_token)] = text_embeds[input_ids == torch.full_like(input_ids, self.tokenizer.track_box_token)] * 0 + boxes_emb.to(text_embeds.device)
192
+ logger.info(f'embeddings:{text_embeds[input_ids == torch.full_like(input_ids, self.tokenizer.track_box_token)].shape}')
193
+ visual = None
194
+ visual_idx = None
195
+
196
+ if image is not None:
197
+
198
+ B, T, C, H, W = image.shape
199
+ image = image.permute(0, 2, 1, 3, 4)
200
+
201
+ instruction = None
202
+
203
+ prompt_image_embeds = self.encode_vision(image, instruction)
204
+
205
+ visual = prompt_image_embeds
206
+
207
+ prompt_image_embeds = self.project_up(prompt_image_embeds) # 768 -> 4096
208
+ prompt_image_embeds = prompt_image_embeds.view(-1, prompt_image_embeds.shape[-1])
209
+
210
+ visual_idx = image_idx
211
+
212
+ prompt_image_embeds = prompt_image_embeds.to(dtype=text_embeds.dtype)
213
+
214
+ text_embeds[image_idx == 1] = torch.zeros_like(text_embeds[image_idx == 1]) + prompt_image_embeds.to(text_embeds.device)
215
+
216
+
217
+ elif video is not None:
218
+ if len(video.shape) == 5:
219
+ B, T, C, H, W = video.shape
220
+ N = 1
221
+ if self.use_clip:
222
+ video = video.reshape(B*self.num_clips, T//self.num_clips, C, H, W) # [16, 8, 3, 224, 224]
223
+ else:
224
+ B, N, T, C, H, W = video.shape
225
+
226
+ video = video.permute(0,2,1,3,4) #
227
+
228
+
229
+ prompt_video_embeds = self.encode_vision(video, instruction=time_instructions) # [2, 96, 768]
230
+ if self.use_clip:
231
+ prompt_video_embeds = prompt_video_embeds.reshape(B,-1,prompt_video_embeds.shape[-1]) # [2,8*96,768]
232
+ batch_size, img_len, token_dim = prompt_video_embeds.shape
233
+ prompt_video_embeds = prompt_video_embeds.view(batch_size, img_len // self.token_merge_len, self.token_merge_len * token_dim) # [B, 768//4, 4*768] = [2, 192, 3072]
234
+ prompt_video_embeds = self.merge_proj(prompt_video_embeds) # [2, 192, 4096]
235
+ prompt_video_embeds = prompt_video_embeds.view(-1, prompt_video_embeds.shape[-1]) # [2*192, 4096]
236
+
237
+ else:
238
+ prompt_video_embeds = self.project_up(prompt_video_embeds) # [2, 96, 4096]
239
+
240
+ prompt_video_embeds = prompt_video_embeds.view(-1, prompt_video_embeds.shape[-1])
241
+ visual_idx = video_idx
242
+
243
+
244
+ text_embeds[video_idx == 1] = torch.zeros_like(text_embeds[video_idx == 1]) + prompt_video_embeds.to(text_embeds.device).to(text_embeds.dtype)
245
+
246
+ else:
247
+ logger.warning(f"No visual input provided, input_ids: {input_ids}")
248
+
249
+
250
+ for idx, text_embed in enumerate(text_embeds):
251
+ if text_embeds[idx][input_ids[idx].flatten() == self.tokenizer.box_token].shape[0] != 0:
252
+ text_embeds[idx][input_ids[idx].flatten() == self.tokenizer.box_token] = torch.zeros_like(text_embeds[idx][input_ids[idx] == self.tokenizer.box_token]) + torch.cat([self.box_token.squeeze(0)] * (text_embeds[idx][input_ids[idx] == self.tokenizer.box_token]).shape[0]).to(text_embeds.dtype)
253
+ if text_embeds[idx][input_ids[idx].flatten() == self.tokenizer.temp_token].shape[0] != 0:
254
+ text_embeds[idx][input_ids[idx].flatten() == self.tokenizer.temp_token] = torch.zeros_like(text_embeds[idx][input_ids[idx] == self.tokenizer.temp_token]) + self.temporal_token
255
+ if text_embeds[idx][input_ids[idx].flatten() == self.tokenizer.track_token].shape[0] != 0:
256
+ text_embeds[idx][input_ids[idx].flatten() == self.tokenizer.track_token] = torch.zeros_like(text_embeds[idx][input_ids[idx] == self.tokenizer.track_token]) + self.track_token
257
+
258
+ if return_visual:
259
+ return text_embeds, visual, visual_idx
260
+
261
+ return text_embeds
262
+
263
+
264
+
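pad_text_embeds splices the projected visual tokens into the text embedding sequence by boolean indexing on the placeholder positions (video_idx == 1 / image_idx == 1). A small shape-only sketch of that mechanic with made-up sizes:

import torch

hidden = 8
text_embeds = torch.zeros(1, 6, hidden)              # [batch, seq, hidden]
video_idx = torch.tensor([[0, 1, 1, 1, 0, 0]])       # 1 marks visual placeholder slots

visual_tokens = torch.randn(3, hidden)               # one embedding per placeholder slot
text_embeds[video_idx == 1] = visual_tokens          # in-place splice, as in pad_text_embeds

print(text_embeds[0, 1:4].equal(visual_tokens))      # True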
265
+ def temporal_decode(self, temporal_embedding):
266
+ pred_sted = self.temporal_embed(temporal_embedding)
267
+ pred_actioness = self.action_embed(temporal_embedding)
268
+ return pred_sted, pred_actioness
269
+
270
+
271
+ def box_loss2(self, src_boxes, target_boxes):
272
+ src_boxes = torch.cat(src_boxes)
273
+ target_boxes = torch.cat(target_boxes)
274
+
275
+ loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
276
+ loss_bbox = self.masked_loss(loss_bbox, 0)
277
+ mask = (src_boxes[2:] >= src_boxes[:2]).all(-1)
278
+ src_boxes = src_boxes[mask]
279
+ target_boxes = target_boxes[mask]
280
+
281
+ return loss_bbox
282
+
283
+ def box_loss(self, src_boxes, target_boxes):
284
+ src_boxes = torch.cat(src_boxes)
285
+ target_boxes = torch.cat(target_boxes)
286
+
287
+ loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
288
+ loss_bbox = self.masked_loss(loss_bbox, 0)
289
+ mask = (src_boxes[:, 2:] >= src_boxes[ :, :2]).all(-1)
290
+ src_boxes = src_boxes[mask]
291
+ target_boxes = target_boxes[mask]
292
+
293
+ if src_boxes.shape[0] > 0:
294
+ loss_giou = 1 - torch.diag(generalized_box_iou(
295
+ src_boxes,
296
+ target_boxes))
297
+ loss_giou = self.masked_loss(loss_giou, 0)
298
+ else:
299
+ loss_giou = torch.tensor(2, dtype=src_boxes.dtype)
300
+ iou, union = box_iou(src_boxes, target_boxes)
301
+
302
+ return loss_bbox * 2 + loss_giou / 5
303
+
304
+ def find_boundaries_torch(self, mask):
305
+
306
+ from skimage.segmentation import find_boundaries
307
+ mask_np = mask.to(torch.bool).numpy()
308
+ boundaries = find_boundaries(mask_np, mode='outer')
309
+ boundary_points = np.argwhere(boundaries)
310
+ if boundary_points.size == 0:
311
+ return torch.tensor([-1, -1, -1, -1], dtype = torch.bfloat16)
312
+ h0, w0 = boundary_points.min(axis=0)
313
+ h1, w1 = boundary_points.max(axis=0)
314
+ return torch.tensor([w0 / mask.shape[1], h0 / mask.shape[0], w1 / mask.shape[1], h1 / mask.shape[0]], dtype = torch.bfloat16)
315
+
316
+
317
+ def sam_loss(self, sam_outputs, gt_masks):
318
+ bound1 = self.find_boundaries_torch(gt_masks[:,:,:1].squeeze(2).cpu())
319
+ bound2 = self.find_boundaries_torch(sam_outputs[:,:,:1].squeeze(2).cpu())
320
+
321
+ lossl1 = F.l1_loss(bound1, bound2, reduction='none')
322
+ lossl1 = self.masked_loss(lossl1, 0)
323
+
324
+ loss_iou = self.iou_loss(sam_outputs, gt_masks)
325
+ loss_dice = self.dice_loss(sam_outputs, gt_masks)
326
+
327
+ # print(f'mask loss:{loss_iou, loss_dice}')
328
+ return loss_iou + loss_dice + lossl1
329
+
330
+ def masked_loss(self, loss, n):
331
+ mask = torch.ones_like(loss)
332
+ # mask[-n:] = 1e-10
333
+ loss = (loss*mask).sum()/(mask.sum())
334
+ return loss
335
+
336
+ def encode_vision(
337
+ self,
338
+ image,
339
+ instruction
340
+ ):
341
+ device = image.device
342
+ B = image.shape[0]
343
+ T = image.shape[2]
344
+ use_image = True if T == 1 else False
345
+ image_embeds = self.vision_encoder(image, use_image=use_image)
346
+ C = image_embeds.shape[-1]
347
+ image_embeds = image_embeds.reshape(B, -1, C)
348
+ image_embeds = self.vision_layernorm(image_embeds).to(device) # [B, T*L, C]
349
+
350
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
351
+ if self.extra_num_query_token > 0:
352
+ query_tokens = torch.cat([self.query_tokens, self.extra_query_tokens], dim=1)
353
+ query_tokens = query_tokens.expand(image_embeds.shape[0], -1, -1)
354
+ if instruction is not None:
355
+ text_Qformer = self.qformer_tokenizer(
356
+ instruction,
357
+ padding='longest',
358
+ truncation=True,
359
+ max_length=512,
360
+ return_tensors="pt",
361
+ ).to(image_embeds.device)
362
+ query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image_embeds.device)
363
+ Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
364
+ query_output = self.qformer.bert(
365
+ text_Qformer.input_ids,
366
+ attention_mask=Qformer_atts,
367
+ query_embeds=query_tokens,
368
+ encoder_hidden_states=image_embeds,
369
+ encoder_attention_mask=image_atts,
370
+ return_dict=True,
371
+ )
372
+ else:
373
+ query_output = self.qformer.bert(
374
+ query_embeds=query_tokens,
375
+ encoder_hidden_states=image_embeds,
376
+ encoder_attention_mask=image_atts,
377
+ return_dict=True,
378
+ )
379
+
380
+ return query_output.last_hidden_state[:, :query_tokens.size(1), :]
381
+
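encode_vision returns only the first query_tokens.size(1) positions of the Q-Former output, i.e. the learned queries after cross-attending to the frozen ViT features; any instruction tokens appended behind them are dropped. A shape walk-through under assumed sizes (comments only, not executable model code):

# assumed sizes for illustration only
# image_embeds : [B, T*L, 768]   frozen ViT tokens after vision_layernorm
# query_tokens : [B, Q, 768]     Q = num query tokens (+ extra_num_query_token)
# qformer.bert(query_embeds=query_tokens,
#              encoder_hidden_states=image_embeds, ...)
#   -> last_hidden_state : [B, Q (+ instruction length, if given), 768]
# slicing [:, :Q, :] keeps exactly the Q visual query outputs that are later
# projected to the LLM hidden size via project_up or merge_proj.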
382
+ def generate_caption(
383
+ self,
384
+ input_ids,
385
+ attention_mask,
386
+ image_idx = None,
387
+ video_idx = None,
388
+ image: Optional[torch.Tensor] = None,
389
+ video: Optional[torch.Tensor] = None,
390
+ num_beams=1,
391
+ max_new_tokens=200,
392
+ do_sample=True,
393
+ top_p=0.9,
394
+ top_k=None,
395
+ temperature=1.0,
396
+ length_penalty=1,
397
+ repetition_penalty=1.0,
398
+ ):
399
+ text_embeds = self.pad_text_embeds(input_ids=input_ids, image=image, video=video, image_idx=image_idx, video_idx=video_idx)
400
+ outputs = self.lm.generate(
401
+ inputs_embeds=text_embeds,
402
+ attention_mask=attention_mask,
403
+ num_beams=num_beams,
404
+ max_new_tokens=max_new_tokens,
405
+ do_sample=do_sample,
406
+ min_length=1,
407
+ top_p=top_p,
408
+ top_k=top_k,
409
+ temperature=temperature,
410
+ length_penalty=length_penalty,
411
+ repetition_penalty=repetition_penalty,
412
+ )
413
+
414
+ return outputs
415
+
416
+ def generate_caption_bbox(
417
+ self,
418
+ input_ids,
419
+ attention_mask,
420
+ labels,
421
+ image_idx = None,
422
+ video_idx = None,
423
+ image: Optional[torch.Tensor] = None,
424
+ video: Optional[torch.Tensor] = None,
425
+ num_beams=1,
426
+ max_new_tokens=200,
427
+ do_sample=True,
428
+ top_p=0.9,
429
+ top_k=None,
430
+ temperature=0.9,
431
+ length_penalty=1,
432
+ repetition_penalty=1.0,
433
+ ):
434
+ text_embeds = self.pad_text_embeds(input_ids=input_ids, image=image, video=video, image_idx=image_idx, video_idx=video_idx)
435
+ outputs = self.lm.generate(
436
+ inputs_embeds=text_embeds,
437
+ attention_mask=attention_mask,
438
+ num_beams=num_beams,
439
+ max_new_tokens=max_new_tokens,
440
+ do_sample=do_sample,
441
+ min_length=1,
442
+ top_p=top_p,
443
+ top_k=top_k,
444
+ temperature=temperature,
445
+ length_penalty=length_penalty,
446
+ repetition_penalty=repetition_penalty,
447
+ )
448
+ decoded_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
449
+ # torch.save({'text':decoded_text, 'output':{outputs}}, 'tmp.pth')
450
+ # print(decoded_text)
451
+ return outputs
452
+
453
+ def generate_temporal(self,
454
+ input_ids: torch.LongTensor = None,
455
+ attention_mask: Optional[torch.Tensor] = None,
456
+ labels: Optional[torch.LongTensor] = None,
457
+ image: Optional[torch.Tensor] = None,
458
+ video: Optional[torch.Tensor] = None,
459
+ instruction = None,
460
+ video_idx = None,
461
+ image_idx = None,
462
+ boxes = None,
463
+ text_input = None,
464
+ video_info = None,
465
+ temporal_labels = None):
466
+
467
+ if text_input is not None:
468
+ time_instructions = self.get_clip_time_instruct(text_input)
469
+ else:
470
+ time_instructions = None
471
+ text_embeds = self.pad_text_embeds(input_ids=input_ids, image=image, video=video, return_visual=False,
472
+ video_idx=video_idx, image_idx=image_idx, instruction = instruction,
473
+ output_boxes = boxes, time_instructions = time_instructions)
474
+
475
+ # TODO
476
+ outputs = self.lm(
477
+ inputs_embeds=text_embeds,
478
+ attention_mask=attention_mask,
479
+ labels=labels,
480
+ output_hidden_states=True,
481
+ return_dict=True,
482
+ )
483
+
484
+ if temporal_labels is not None:
485
+ start_sec = temporal_labels["start_sec"]
486
+ end_sec = temporal_labels["end_sec"]
487
+ fps = video_info['fps']
488
+ frame_indices = video_info['frame_indices']
489
+
490
+ last_hidden_states = outputs.hidden_states[-1] # [2,1024, 4096]
491
+ last_hidden_states = last_hidden_states.view(-1, last_hidden_states.size(-1)) # [2048, 4096]
492
+ loc_positions = (input_ids.flatten()==self.tokenizer.temp_place_ids).nonzero().flatten() #
493
+ selected_hidden_states = last_hidden_states[loc_positions]
494
+ selected_hidden_states = selected_hidden_states.view(input_ids.shape[0], -1 ,selected_hidden_states.shape[-1]) # [2, 64, 4096]
495
+
496
+ # just for debug
497
+
498
+ # vis_embed = vis_embed[:,:64,:]
499
+
500
+ pred_sted, pred_actionness = self.temporal_decode(selected_hidden_states) # [2,64,2] [2,64,1]
501
+
502
+ pred_sted = self.postprocess(pred_sted, frame_indices)
503
+ pred_sec_s = pred_sted[0][0] / fps[0][0].item()
504
+ pred_sec_e = pred_sted[0][1] / fps[0][0].item()
505
+
506
+ output_file = "predictions2.jsonl"
507
+ prediction = {"pred_sec_s": round(pred_sec_s, 1), "pred_sec_e": round(pred_sec_e, 1), "start_sec":float(start_sec[0]), "end_sec": float(end_sec[0])}
508
+
509
+ with open(output_file, 'a') as f:
510
+ json.dump(prediction, f)
511
+ f.write('\n')
512
+
513
+ return outputs
514
+
515
+ def generate_seg(self, input_ids, attention_mask, labels, image, image_idx, video, video_idx, input_boxes, size_hw, sam_images):
516
+ device = input_ids.device
517
+ prompt = input_ids
518
+ l_prompt = len(input_ids)
519
+ temperature = 1e-5
520
+ max_new_tokens = 20
521
+ guide_w = 5
522
+ stop_str = '</s>'
523
+ bbox = []
524
+ output_ids = list(input_ids[0])
525
+ text_embeds = self.pad_text_embeds(input_ids=input_ids, image=image, video=video, image_idx=image_idx, video_idx=video_idx, return_visual=False,
526
+ instruction = None, output_boxes=None, input_boxes=input_boxes)
527
+ for i in range(max_new_tokens):
528
+ if i == 0:
529
+ outputs = self.lm(
530
+ inputs_embeds=text_embeds,
531
+ attention_mask=attention_mask,
532
+ output_hidden_states=True,
533
+ return_dict=True,
534
+ )
535
+ logits = outputs.logits
536
+ past_key_values = outputs.past_key_values
537
+ else:
538
+ attention_mask = torch.ones(1, past_key_values[0][0].shape[-2] + 1, device=device)
539
+ last_text_embeds = self.lm.get_input_embeddings()(torch.tensor(output_ids[-1], device=device).long()).detach().unsqueeze(0)
540
+ last_text_embeds = last_text_embeds.unsqueeze(0)
541
+
542
+ out = self.lm(
543
+ input_ids=None,
544
+ use_cache=True,
545
+ attention_mask=attention_mask,
546
+ output_hidden_states=True,
547
+ inputs_embeds=last_text_embeds,
548
+ past_key_values=past_key_values,
549
+ )
550
+ logits = out.logits
551
+ past_key_values = out.past_key_values
552
+ if logits is not None:
553
+ last_token_logits = logits[0][-1]
554
+ if temperature < 1e-4:
555
+ token = int(torch.argmax(last_token_logits))
556
+ else:
557
+ probs = torch.softmax(last_token_logits / temperature, dim=-1)
558
+ token = int(torch.multinomial(probs, num_samples=1))
559
+ output_ids.append(token)
560
+ ret = self.tokenizer.decode(token)
561
+ if ret == '<box_begin>':
562
+ attention_mask = torch.ones(1, past_key_values[0][0].shape[-2] + 1, device=device)
563
+ bbox_embeds = self.box_token.bfloat16()
564
+ out = self.lm(
565
+ inputs_embeds=bbox_embeds,
566
+ use_cache=True,
567
+ attention_mask=attention_mask,
568
+ output_hidden_states=True,
569
+ past_key_values=past_key_values
570
+ )
571
+ last_hidden_states = out.hidden_states[-1]
572
+ selected_hidden_states = last_hidden_states[0][0]
573
+ bbox.append(self.loc_decoder(selected_hidden_states))
574
+ last_token_logits = logits[0][-1]
575
+ if temperature < 1e-4:
576
+ token = int(torch.argmax(last_token_logits))
577
+ else:
578
+ probs = torch.softmax(last_token_logits / temperature, dim=-1)
579
+ token = int(torch.multinomial(probs, num_samples=1))
580
+ if ret == '<track_begin>':
581
+ attention_mask = torch.ones(1, past_key_values[0][0].shape[-2] + 1, device=device)
582
+ tracking_embeds = self.track_token
583
+ out = self.lm(
584
+ inputs_embeds=tracking_embeds,
585
+ use_cache=True,
586
+ attention_mask=attention_mask,
587
+ output_hidden_states=True,
588
+ past_key_values=past_key_values
589
+ )
590
+ last_hidden_states = out.hidden_states[-1]
591
+ selected_hidden_states = last_hidden_states[0][0].to(dtype = torch.bfloat16)
592
+
593
+ embed_sam_boxes = self.track_embed(selected_hidden_states).reshape(1, 3, 256)
594
+
595
+ inference_state = self.sam.init_state_images(sam_images, size_hw[0][0], size_hw[0][1])
596
+ gt_embeds = self.sam.get_prompt_embeding(inference_state, None, None, False, input_boxes[0].cuda(), device = text_embeds.device)
597
+ ann_frame_idx = 0
598
+ ann_obj_id = 0
599
+ box = np.array([0, 0, 0, 0], dtype=np.float32)
600
+ _, out_obj_ids, out_mask_logits = self.sam.add_new_box_embeding(
601
+ inference_state=inference_state,
602
+ frame_idx=ann_frame_idx,
603
+ obj_id=ann_obj_id,
604
+ box=box,
605
+ box_embeding=embed_sam_boxes,
606
+ )
607
+ video_segments = {} # video_segments contains the per-frame segmentation results
608
+ for out_frame_idx, out_obj_ids, out_mask_logits in self.sam.propagate_in_video(inference_state):
609
+ video_segments[out_frame_idx] = {
610
+ out_obj_id: (out_mask_logits[i] > 0.0)
611
+ for i, out_obj_id in enumerate(out_obj_ids)
612
+ }
613
+ video_segments = [video_segments[tt][0] for tt in video_segments]
614
+ # bbox = model.find_boundaries_torch(video_segments[0].squeeze(0).cpu())
615
+ # return ret, [], video_segments
616
+
617
+ if (ret == '</s>'):
618
+ break
619
+ ret = self.tokenizer.decode(output_ids)
620
+ del past_key_values
621
+ return ret, bbox, video_segments
622
+
623
+ def generate_answer(self, tokenizer, instruction, msg, user_prompt, media_type="video",video_tensor=None, image_tensor=None, answer_prompt=None, chat_history=[],return_history=False, debug=False, generation_config={}):
624
+ input_ids, attention_masks, labels = [], [], []
625
+
626
+ conversation = ""
627
+ if instruction:
628
+ conversation += instruction
629
+ conversation += (
630
+ "[INST]" + " "
631
+ )
632
+
633
+ if media_type == 'image':
634
+ conversation +=( "<Image>" + IMG_TOKEN + "</Image>")
635
+ else:
636
+ conversation += ("<Video>" + VID_TOKEN + "</Video>")
637
+
638
+ conversation += ( msg.rstrip() + "[/INST]")
639
+
640
+ for q,a in chat_history:
641
+ conversation += (" [INST] " + q + " [/INST]")
642
+ conversation += (a + "</s>")
643
+
644
+ conversation += (" [INST] " + user_prompt + " [/INST]")
645
+ conversation += ("")
646
+ if answer_prompt:
647
+ conversation += ("Best Option: (")
648
+ total_len = 0
649
+ indexs = []
650
+ if debug:
651
+ print(conversation)
652
+
653
+ tokenized = tokenizer.build_input_ids([conversation],
654
+ max_length=1024,
655
+ add_special_tokens=True,
656
+ truncation=False,
657
+ padding=False,
658
+ return_tensors='pt',
659
+ image=image_tensor,
660
+ video=video_tensor,
661
+ require_video=True)
662
+ if video_tensor is not None:
663
+ generation_output = self.generate_caption(
664
+ tokenized['input_ids'].unsqueeze(0).to(self.device),
665
+ tokenized['attention_mask'].unsqueeze(0).to(self.device),
666
+ video_idx = tokenized['video_index'].unsqueeze(0),
667
+ video = video_tensor.unsqueeze(0).to(self.device,dtype=torch.bfloat16),
668
+ do_sample=False
669
+ )
670
+ elif image_tensor is not None:
671
+ generation_output = self.generate_caption(
672
+ tokenized['input_ids'].unsqueeze(0).to(self.device),
673
+ tokenized['attention_mask'].unsqueeze(0).to(self.device),
674
+ image_idx = tokenized['image_index'].unsqueeze(0),
675
+ image = image_tensor.unsqueeze(0).to(self.device,dtype=torch.bfloat16),
676
+ do_sample=False
677
+ )
678
+ response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
679
+ if debug:
680
+ print(response)
681
+ return response, chat_history
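A hedged inference sketch for generate_answer. The repo id, the custom tokenizer (which must provide build_input_ids), and the video preprocessing are assumptions for illustration, not part of this file:

import torch
from transformers import AutoModel, AutoTokenizer

ckpt = "OpenGVLab/VideoChat-TPO"          # placeholder repo id
model = AutoModel.from_pretrained(ckpt, trust_remote_code=True).to(torch.bfloat16).cuda()
tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)

video = torch.rand(16, 3, 224, 224)       # [T, C, H, W]; real preprocessing omitted
response, history = model.generate_answer(
    tokenizer,
    instruction=None,
    msg="",
    user_prompt="Describe the video in detail.",
    media_type="video",
    video_tensor=video,
)
print(response)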
modeling_vit.py ADDED
@@ -0,0 +1,487 @@
1
+ import logging
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torch.utils.checkpoint as checkpoint
7
+ from functools import partial
8
+
9
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ def _cfg(url='', **kwargs):
15
+ return {
16
+ 'url': url,
17
+ 'num_classes': 400, 'input_size': (3, 224, 224), 'pool_size': None,
18
+ 'crop_pct': .9, 'interpolation': 'bicubic',
19
+ 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
20
+ **kwargs
21
+ }
22
+
23
+ class MLP(nn.Module):
24
+ """Very simple multi-layer perceptron (also called FFN)"""
25
+
26
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers, dropout=0):
27
+ super().__init__()
28
+ self.num_layers = num_layers
29
+ h = [hidden_dim] * (num_layers - 1)
30
+ self.layers = nn.ModuleList(
31
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
32
+ )
33
+ self.dropout = dropout
34
+ if dropout:
35
+ self.dropout = nn.Dropout(dropout)
36
+
37
+ def forward(self, x):
38
+ for i, layer in enumerate(self.layers):
39
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
40
+ if self.dropout and i < self.num_layers:
41
+ x = self.dropout(x)
42
+ return x
43
+
44
+ class PostProcess(nn.Module):
45
+ """ This module converts the model's start/end logits into predicted [start, end] frame spans"""
46
+
47
+ @torch.no_grad()
48
+ def forward(self, out_sted, frames_id):
49
+ """Perform the computation for inference evaluation
50
+ """
51
+ # import pdb; pdb.set_trace()
52
+
53
+ b, t, _ = out_sted.shape
54
+ device = out_sted.device
55
+ temp_prob_map = torch.zeros(b,t,t).to(device)
56
+ inf = -1e32
57
+ for i_b in range(len(frames_id)):
58
+ duration = len(frames_id[0])
59
+ sted_prob = (torch.ones(t, t) * inf).tril(0).to(device)
60
+ sted_prob[duration:,:] = inf
61
+ sted_prob[:,duration:] = inf
62
+ temp_prob_map[i_b,:,:] = sted_prob
63
+
64
+ temp_prob_map += F.log_softmax(out_sted[:, :, 0], dim=1).unsqueeze(2) + \
65
+ F.log_softmax(out_sted[:, :, 1], dim=1).unsqueeze(1)
66
+
67
+ pred_steds = []
68
+ for i_b in range(b):
69
+ prob_map = temp_prob_map[i_b] # [T * T]
70
+ frame_id_seq = frames_id[i_b]
71
+ prob_seq = prob_map.flatten(0)
72
+ max_tstamp = prob_seq.max(dim=0)[1].item()
73
+ start_idx = max_tstamp // t
74
+ end_idx = max_tstamp % t
75
+ pred_sted = [frame_id_seq[start_idx], frame_id_seq[end_idx]+1]
76
+ pred_steds.append(pred_sted)
77
+
78
+ return pred_steds
79
+
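PostProcess sums start and end log-probabilities into a T x T score map, masks out pairs where the end does not come after the start, and takes the argmax. The same decoding for a single clip, written out with random logits and four sampled frames:

import torch
import torch.nn.functional as F

t = 4                                         # number of sampled frames
frame_ids = [0, 8, 16, 24]                    # frame indices of the sampled frames
out_sted = torch.randn(t, 2)                  # per-frame start / end logits

neg_inf = -1e32
mask = (torch.ones(t, t) * neg_inf).tril(0)   # forbid end <= start

score = mask \
    + F.log_softmax(out_sted[:, 0], dim=0).unsqueeze(1) \
    + F.log_softmax(out_sted[:, 1], dim=0).unsqueeze(0)

best = score.flatten().argmax().item()
start_idx, end_idx = best // t, best % t
print(frame_ids[start_idx], frame_ids[end_idx] + 1)    # predicted [start, end) frame span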
80
+ class DropPath(nn.Module):
81
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
82
+ """
83
+ def __init__(self, drop_prob=None):
84
+ super(DropPath, self).__init__()
85
+ self.drop_prob = drop_prob
86
+
87
+ def forward(self, x):
88
+ return drop_path(x, self.drop_prob, self.training)
89
+
90
+ def extra_repr(self) -> str:
91
+ return 'p={}'.format(self.drop_prob)
92
+
93
+
94
+ class Mlp(nn.Module):
95
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
96
+ super().__init__()
97
+ out_features = out_features or in_features
98
+ hidden_features = hidden_features or in_features
99
+ self.fc1 = nn.Linear(in_features, hidden_features)
100
+ self.act = act_layer()
101
+ self.fc2 = nn.Linear(hidden_features, out_features)
102
+ self.drop = nn.Dropout(drop)
103
+
104
+ def forward(self, x):
105
+ x = self.fc1(x)
106
+ x = self.act(x)
107
+ x = self.drop(x)
108
+ x = self.fc2(x)
109
+ x = self.drop(x)
110
+ return x
111
+
112
+
113
+ class Attention(nn.Module):
114
+ def __init__(
115
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
116
+ proj_drop=0., attn_head_dim=None):
117
+ super().__init__()
118
+ self.num_heads = num_heads
119
+ head_dim = dim // num_heads
120
+ if attn_head_dim is not None:
121
+ head_dim = attn_head_dim
122
+ all_head_dim = head_dim * self.num_heads
123
+ self.scale = qk_scale or head_dim ** -0.5
124
+
125
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
126
+ if qkv_bias:
127
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
128
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
129
+ else:
130
+ self.q_bias = None
131
+ self.v_bias = None
132
+
133
+ self.attn_drop = nn.Dropout(attn_drop)
134
+ self.proj = nn.Linear(all_head_dim, dim)
135
+ self.proj_drop = nn.Dropout(proj_drop)
136
+
137
+ def forward(self, x):
138
+ B, N, C = x.shape
139
+ qkv_bias = None
140
+ if self.q_bias is not None:
141
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
142
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
143
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
144
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
145
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
146
+
147
+ q = q * self.scale
148
+ attn = (q @ k.transpose(-2, -1))
149
+
150
+ attn = attn.softmax(dim=-1)
151
+ attn = self.attn_drop(attn)
152
+
153
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
154
+ x = self.proj(x)
155
+ x = self.proj_drop(x)
156
+ return x
157
+
158
+
159
+ class Block(nn.Module):
160
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
161
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
162
+ attn_head_dim=None):
163
+ super().__init__()
164
+ self.norm1 = norm_layer(dim)
165
+ self.attn = Attention(
166
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
167
+ attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim)
168
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
169
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
170
+ self.norm2 = norm_layer(dim)
171
+ mlp_hidden_dim = int(dim * mlp_ratio)
172
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
173
+
174
+ if init_values > 0:
175
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
176
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
177
+ else:
178
+ self.gamma_1, self.gamma_2 = None, None
179
+
180
+ def forward(self, x):
181
+ if self.gamma_1 is None:
182
+ x = x + self.drop_path(self.attn(self.norm1(x)))
183
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
184
+ else:
185
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
186
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
187
+ return x
188
+
189
+
190
+ class PatchEmbed(nn.Module):
191
+ """ Image to Patch Embedding
192
+ """
193
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2):
194
+ super().__init__()
195
+ img_size = to_2tuple(img_size)
196
+ patch_size = to_2tuple(patch_size)
197
+ self.tubelet_size = int(tubelet_size)
198
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (num_frames // self.tubelet_size)
199
+ self.img_size = img_size
200
+ self.patch_size = patch_size
201
+ self.num_patches = num_patches
202
+ self.proj = nn.Conv3d(
203
+ in_channels=in_chans, out_channels=embed_dim,
204
+ kernel_size=(self.tubelet_size, patch_size[0], patch_size[1]),
205
+ stride=(self.tubelet_size, patch_size[0], patch_size[1])
206
+ )
207
+ logger.info(f'Num of patches: {num_patches}')
208
+
209
+ def forward(self, x, **kwargs):
210
+ B, C, T, H, W = x.shape
211
+ # FIXME look at relaxing size constraints
212
+ # assert H == self.img_size[0] and W == self.img_size[1], \
213
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
214
+ x = self.proj(x).flatten(2).transpose(1, 2)
215
+ return x
216
+
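PatchEmbed flattens a clip into (H/patch) * (W/patch) * (T/tubelet) tokens via a 3-D convolution. A quick check using the class defined above with assumed default sizes (224x224 input, patch 16, 16 frames, tubelet 1):

import torch

patch_embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3,
                         embed_dim=768, num_frames=16, tubelet_size=1)

video = torch.randn(2, 3, 16, 224, 224)       # [B, C, T, H, W]
tokens = patch_embed(video)
print(patch_embed.num_patches)                # 14 * 14 * 16 = 3136
print(tokens.shape)                           # torch.Size([2, 3136, 768])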
217
+ # sin-cos position encoding
218
+ # https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
219
+ def get_sinusoid_encoding_table(n_position, d_hid, ckpt_num_frame=-1, cur_frame=12):
220
+ ''' Sinusoid position encoding table '''
221
+ # TODO: make it with torch instead of numpy
222
+ def get_position_angle_vec(position):
223
+ return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
224
+
225
+ if ckpt_num_frame != -1 and ckpt_num_frame != cur_frame:
226
+ logger.info(f"Interpolate position embedding")
227
+ logger.info(f"Testing frame: {cur_frame}")
228
+ logger.info(f"Checkpoint frame: {ckpt_num_frame}")
229
+
230
+ T = ckpt_num_frame # checkpoint frame
231
+ new_T = cur_frame # testing frame
232
+ n_position = n_position // new_T * T # generate checkpoint position embedding
233
+ sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
234
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
235
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
236
+ sinusoid_table = torch.tensor(sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0)
237
+ # interpolate
238
+ P = int((n_position // T) ** 0.5)
239
+ C = d_hid
240
+ sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C)
241
+ sinusoid_table = sinusoid_table.permute(0, 2, 3, 4, 1).reshape(-1, C, T) # BHW, C, T
242
+ sinusoid_table = torch.nn.functional.interpolate(sinusoid_table, size=new_T, mode='linear')
243
+ sinusoid_table = sinusoid_table.reshape(1, P, P, C, new_T).permute(0, 4, 1, 2, 3) # B, T, H, W, C
244
+ sinusoid_table = sinusoid_table.flatten(1, 3)
245
+ return sinusoid_table
246
+ else:
247
+ sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
248
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
249
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
250
+ return torch.tensor(sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0)
251
+
252
+
253
+ def get_sinusoid_encoding_table2(n_position=784, d_hid=1024, cur_frame=8, ckpt_num_frame=4, pre_n_position=784):
254
+ ''' Sinusoid position encoding table '''
255
+ # TODO: make it with torch instead of numpy
256
+ def get_position_angle_vec(position):
257
+ return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
258
+
259
+ # generate checkpoint position embedding
260
+ sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(pre_n_position)])
261
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
262
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
263
+ sinusoid_table = torch.tensor(sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0)
264
+
265
+ print(f"n_position: {n_position}")
266
+ print(f"pre_n_position: {pre_n_position}")
267
+
268
+ if n_position != pre_n_position:
269
+ T = ckpt_num_frame # checkpoint frame
270
+ P = 14 # checkpoint size
271
+ C = d_hid
272
+ new_P = int((n_position // cur_frame) ** 0.5) # testing size
273
+ print(f'Pretraining uses 14x14, but current version is {new_P}x{new_P}')
274
+ print(f'Interpolate the position embedding')
275
+ sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C)
276
+ sinusoid_table = sinusoid_table.reshape(-1, P, P, C).permute(0, 3, 1, 2)
277
+ sinusoid_table = torch.nn.functional.interpolate(
278
+ sinusoid_table, size=(new_P, new_P), mode='bicubic', align_corners=False)
279
+ # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
280
+ sinusoid_table = sinusoid_table.permute(0, 2, 3, 1).reshape(-1, T, new_P, new_P, C)
281
+ sinusoid_table = sinusoid_table.flatten(1, 3) # B, THW, C
282
+
283
+ if cur_frame != ckpt_num_frame:
284
+ print(f'Pretraining uses 4 frames, but current frame is {cur_frame}')
285
+ print(f'Interpolate the position embedding')
286
+ T = ckpt_num_frame # checkpoint frame
287
+ new_T = cur_frame # testing frame
288
+ # interpolate
289
+ P = int((n_position // cur_frame) ** 0.5) # testing size
290
+ C = d_hid
291
+ sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C)
292
+ sinusoid_table = sinusoid_table.permute(0, 2, 3, 4, 1).reshape(-1, C, T) # BHW, C, T
293
+ sinusoid_table = torch.nn.functional.interpolate(sinusoid_table, size=new_T, mode='linear')
294
+ sinusoid_table = sinusoid_table.reshape(1, P, P, C, new_T).permute(0, 4, 1, 2, 3) # B, T, H, W, C
295
+ sinusoid_table = sinusoid_table.flatten(1, 3) # B, THW, C
296
+
297
+ return sinusoid_table
298
+
299
+
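Both helpers return a (1, n_position, d_hid) table that is later added to the patch tokens; the second variant also resizes the spatial grid away from the 14x14 pretraining layout before interpolating over time. A minimal shape check, using hypothetical sizes:

    # hypothetical: 8 frames on a 16x16 grid, 1024-dim embeddings,
    # starting from a 4-frame, 14x14 checkpoint table (pre_n_position = 4*14*14)
    pos = get_sinusoid_encoding_table2(n_position=8 * 16 * 16, d_hid=1024,
                                       cur_frame=8, ckpt_num_frame=4,
                                       pre_n_position=4 * 14 * 14)
    print(pos.shape)  # torch.Size([1, 2048, 1024])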
300
+ class PretrainVisionTransformerEncoder(nn.Module):
301
+ """ Vision Transformer with support for patch or hybrid CNN input stage
302
+ """
303
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12,
304
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
305
+ drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, num_frames=8, tubelet_size=1,
306
+ use_learnable_pos_emb=False,
307
+ use_checkpoint=False, checkpoint_num=0,
308
+ ckpt_num_frame=-1, with_ln=True, return_index=-1
309
+ ):
310
+ super().__init__()
311
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
312
+ self.patch_embed = PatchEmbed(
313
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
314
+ num_frames=num_frames, tubelet_size=tubelet_size
315
+ )
316
+ num_patches = self.patch_embed.num_patches
317
+ self.depth = depth + return_index + 1
318
+ self.use_checkpoint = use_checkpoint
319
+ self.checkpoint_num = checkpoint_num
320
+ logger.info(f"Use checkpoint: {use_checkpoint}")
321
+ logger.info(f"Checkpoint number: {checkpoint_num}")
322
+ logger.info(f"Real runing depth: {self.depth}")
323
+
324
+ # TODO: Add the cls token
325
+ if use_learnable_pos_emb:
326
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
327
+ self.img_pos_embed = nn.Parameter(torch.zeros(1, num_patches//(num_frames//tubelet_size) + 1, embed_dim))
328
+ else:
329
+ # sine-cosine positional embeddings
330
+ if img_size != 224:
331
+ self.pos_embed = get_sinusoid_encoding_table2(num_patches, embed_dim, ckpt_num_frame=ckpt_num_frame, cur_frame=num_frames//tubelet_size)
332
+ self.img_pos_embed = get_sinusoid_encoding_table2(num_patches//(num_frames//tubelet_size), embed_dim, cur_frame=1, ckpt_num_frame=1, pre_n_position=14*14)
333
+ else:
334
+ self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim, ckpt_num_frame=ckpt_num_frame, cur_frame=num_frames//tubelet_size)
335
+ self.img_pos_embed = get_sinusoid_encoding_table(num_patches//(num_frames//tubelet_size), embed_dim)
336
+
337
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
338
+ self.blocks = nn.ModuleList([
339
+ Block(
340
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
341
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
342
+ init_values=init_values)
343
+ for i in range(self.depth)])
344
+
345
+ if with_ln:
346
+ self.norm = norm_layer(embed_dim)
347
+ else:
348
+ self.norm = nn.Identity()
349
+
350
+ if use_learnable_pos_emb:
351
+ trunc_normal_(self.pos_embed, std=.02)
352
+
353
+ @torch.jit.ignore
354
+ def no_weight_decay(self):
355
+ return {'pos_embed', 'cls_token'}
356
+
357
+ def forward_features(self, x, use_image=False):
358
+ x = self.patch_embed(x)
359
+
360
+ if use_image:
361
+ x = x + self.img_pos_embed.type_as(x).to(x.device).clone().detach()
362
+ else:
363
+ x = x + self.pos_embed.type_as(x).to(x.device).clone().detach()
364
+
365
+ B, _, C = x.shape
366
+ x_vis = x
367
+
368
+ for idx, blk in enumerate(self.blocks):
369
+ if self.use_checkpoint and idx < self.checkpoint_num:
370
+ x_vis = checkpoint.checkpoint(blk, x_vis)
371
+ else:
372
+ x_vis = blk(x_vis)
373
+
374
+ # with LN or not
375
+ x_vis = self.norm(x_vis)
376
+ return x_vis
377
+
378
+ def forward(self, x, use_image=False):
379
+ x_vis = self.forward_features(x, use_image)
380
+ return x_vis
381
+
382
+
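When use_checkpoint is True, only the first checkpoint_num blocks are run under torch.utils.checkpoint, trading recomputation in the backward pass for lower activation memory. The generic pattern, sketched here for illustration only:

    from torch.utils import checkpoint

    def run_block(block, x, use_ckpt):
        # if checkpointed, activations of `block` are recomputed in backward instead of stored
        return checkpoint.checkpoint(block, x) if use_ckpt else block(x)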
383
+ class PretrainVisionTransformer(nn.Module):
384
+ """ Vision Transformer with support for patch or hybrid CNN input stage
385
+ """
386
+ def __init__(self,
387
+ img_size=224,
388
+ patch_size=16,
389
+ encoder_in_chans=3,
390
+ encoder_embed_dim=768,
391
+ encoder_depth=12,
392
+ encoder_num_heads=12,
393
+ mlp_ratio=4.,
394
+ qkv_bias=True,
395
+ qk_scale=None,
396
+ drop_rate=0.,
397
+ attn_drop_rate=0.,
398
+ drop_path_rate=0.,
399
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
400
+ init_values=0.,
401
+ use_learnable_pos_emb=False,
402
+ num_frames=8,
403
+ tubelet_size=1,
404
+ use_checkpoint=False,
405
+ checkpoint_num=0,
406
+ ckpt_num_frame=4, # the pretrained model uses 4 frames
407
+ return_index=-1,
408
+ with_ln=False
409
+ ):
410
+ super().__init__()
411
+
412
+ self.encoder = PretrainVisionTransformerEncoder(
413
+ img_size=img_size,
414
+ patch_size=patch_size,
415
+ in_chans=encoder_in_chans,
416
+ embed_dim=encoder_embed_dim,
417
+ depth=encoder_depth,
418
+ num_heads=encoder_num_heads,
419
+ mlp_ratio=mlp_ratio,
420
+ qkv_bias=qkv_bias,
421
+ qk_scale=qk_scale,
422
+ drop_rate=drop_rate,
423
+ attn_drop_rate=attn_drop_rate,
424
+ drop_path_rate=drop_path_rate,
425
+ norm_layer=norm_layer,
426
+ init_values=init_values,
427
+ num_frames=num_frames,
428
+ tubelet_size=tubelet_size,
429
+ use_learnable_pos_emb=use_learnable_pos_emb,
430
+ use_checkpoint=use_checkpoint,
431
+ checkpoint_num=checkpoint_num,
432
+ ckpt_num_frame=ckpt_num_frame,
433
+ with_ln=with_ln,
434
+ return_index=return_index
435
+ )
436
+ logger.info(f'With LN: {with_ln}')
437
+ logger.info(f'Total {encoder_depth} layers')
438
+ logger.info(f'Return {encoder_depth+return_index+1}-th layer')
439
+
440
+ self.apply(self._init_weights)
441
+
442
+ def _init_weights(self, m):
443
+ if isinstance(m, nn.Linear):
444
+ nn.init.xavier_uniform_(m.weight)
445
+ if isinstance(m, nn.Linear) and m.bias is not None:
446
+ nn.init.constant_(m.bias, 0)
447
+ elif isinstance(m, nn.LayerNorm):
448
+ nn.init.constant_(m.bias, 0)
449
+ nn.init.constant_(m.weight, 1.0)
450
+
451
+ @torch.jit.ignore
452
+ def no_weight_decay(self):
453
+ return {'pos_embed', 'cls_token', 'clip_pos_embed'}
454
+
455
+ def forward(self, x, use_image=False):
456
+ T = x.shape[2]
457
+ x_vis = self.encoder(x, use_image) # [B, N_vis, C_e]
458
+ B, TL, C = x_vis.shape
459
+ x_vis = x_vis.view(B, T, TL // T, C)
460
+
461
+ return x_vis
462
+
463
+
464
+ def build_vit(config):
465
+ model = PretrainVisionTransformer(
466
+ img_size=config.vision_encoder.img_size,
467
+ patch_size=config.vision_encoder.patch_size,
468
+ encoder_embed_dim=config.vision_encoder.encoder_embed_dim,
469
+ encoder_depth=config.vision_encoder.encoder_depth,
470
+ encoder_num_heads=config.vision_encoder.encoder_num_heads,
471
+ drop_path_rate=config.vision_encoder.drop_path_rate,
472
+ num_frames=config.vision_encoder.num_frames,
473
+ tubelet_size=config.vision_encoder.tubelet_size,
474
+ use_checkpoint=config.vision_encoder.use_checkpoint,
475
+ checkpoint_num=config.vision_encoder.checkpoint_num,
476
+ return_index=config.vision_encoder.get('return_index', -1),
477
+ with_ln=config.vision_encoder.get('with_ln', False),
478
+ )
479
+ model.default_cfg = _cfg()
480
+ if config.vision_encoder.pretrained:
481
+ logger.info(f"Loading pretrained weights from {config.vision_encoder.pretrained}")
482
+ state_dict = torch.load(config.vision_encoder.pretrained, map_location='cpu')
483
+ model.load_state_dict(state_dict, strict=False)
484
+ else:
485
+ logger.info("No pretrained weights!!!")
486
+ return model
487
+
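build_vit only reads the vision_encoder sub-config, which must support both attribute access and .get. A minimal usage sketch with hypothetical values; EasyDict stands in for the project's real config object, and the rest of modeling_vit.py (Block, _cfg, logger) is assumed to be importable:

    import torch
    from easydict import EasyDict

    cfg = EasyDict(vision_encoder=EasyDict(
        img_size=224, patch_size=16,
        encoder_embed_dim=1024, encoder_depth=24, encoder_num_heads=16,
        drop_path_rate=0.0, num_frames=8, tubelet_size=1,
        use_checkpoint=False, checkpoint_num=0, pretrained=None,
    ))
    vit = build_vit(cfg)                        # logs "No pretrained weights!!!"
    feats = vit(torch.rand(1, 3, 8, 224, 224))  # -> shape (1, 8, 196, 1024)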
special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "<unk>",
17
+ "unk_token": {
18
+ "content": "<unk>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
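These entries are picked up automatically when the tokenizer is loaded, so the BOS/EOS/PAD/UNK tokens do not need to be registered by hand. A minimal sketch, assuming a local copy of this repository (including its tokenizer files) at a placeholder path:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("path/to/this/repo")
    print(tok.bos_token, tok.eos_token, tok.pad_token)  # <s> </s> <unk>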
third_party/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ import logging
2
+ logger = logging.getLogger(__name__)
third_party/cgdetr/cg_detr/__init__.py ADDED
File without changes
third_party/cgdetr/cg_detr/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (168 Bytes).
third_party/cgdetr/cg_detr/__pycache__/attention.cpython-310.pyc ADDED
Binary file (15.2 kB).
third_party/cgdetr/cg_detr/__pycache__/crossattention.cpython-310.pyc ADDED
Binary file (15.4 kB).
third_party/cgdetr/cg_detr/__pycache__/matcher.cpython-310.pyc ADDED
Binary file (4.65 kB).
third_party/cgdetr/cg_detr/__pycache__/misc.cpython-310.pyc ADDED
Binary file (714 Bytes).
third_party/cgdetr/cg_detr/__pycache__/model.cpython-310.pyc ADDED
Binary file (32.6 kB).
third_party/cgdetr/cg_detr/__pycache__/position_encoding.cpython-310.pyc ADDED
Binary file (4.33 kB).
third_party/cgdetr/cg_detr/__pycache__/span_utils.cpython-310.pyc ADDED
Binary file (4.19 kB).
third_party/cgdetr/cg_detr/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (22.7 kB).
third_party/cgdetr/cg_detr/attention.py ADDED
@@ -0,0 +1,394 @@
1
+ # ------------------------------------------------------------------------
2
+ # DAB-DETR
3
+ # Copyright (c) 2022 IDEA. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
7
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
8
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
9
+ # ------------------------------------------------------------------------
10
+ # Modified from DETR (https://github.com/facebookresearch/detr)
11
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
12
+ # ------------------------------------------------------------------------
13
+ # Modified from codes in torch.nn
14
+ # ------------------------------------------------------------------------
15
+
16
+ """
17
+ MultiheadAttention that supports query, key, and value with different dimensions.
18
+ Query, key, and value projections are removed.
19
+ Mostly copy-paste from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/activation.py#L873
20
+ and https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L4837
21
+ """
22
+
23
+ import copy
24
+ from typing import Optional, List
25
+
26
+ import torch
27
+ import torch.nn.functional as F
28
+ from torch import nn, Tensor
29
+
30
+ import warnings
31
+ from typing import Tuple, Optional
32
+
33
+ import torch
34
+ from torch import Tensor
35
+ from torch.nn.modules.linear import Linear
36
+ from torch.nn.init import xavier_uniform_
37
+ from torch.nn.init import constant_
38
+ from torch.nn.init import xavier_normal_
39
+ from torch.nn.parameter import Parameter
40
+ from torch.nn.modules.module import Module
41
+ from torch.nn import functional as F
42
+
43
+ import warnings
44
+ import math
45
+
46
+ from torch._C import _infer_size, _add_docstr
47
+ from torch.nn import _reduction as _Reduction
48
+ from torch.nn.modules import utils
49
+ from torch.nn.modules.utils import _single, _pair, _triple, _list_with_default
50
+ from torch.nn import grad
51
+ from torch import _VF
52
+ from torch._jit_internal import boolean_dispatch, List, Optional, _overload, Tuple
53
+ try:
54
+ from torch.overrides import has_torch_function, handle_torch_function
55
+ except:
56
+ from torch._overrides import has_torch_function, handle_torch_function
57
+ Tensor = torch.Tensor
58
+
59
+ from torch.nn.functional import linear, pad, softmax, dropout
60
+
61
+ class MultiheadAttention(Module):
62
+ r"""Allows the model to jointly attend to information
63
+ from different representation subspaces.
64
+ See reference: Attention Is All You Need
65
+ .. math::
66
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
67
+ \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
68
+ Args:
69
+ embed_dim: total dimension of the model.
70
+ num_heads: parallel attention heads.
71
+ dropout: a Dropout layer on attn_output_weights. Default: 0.0.
72
+ bias: add bias as module parameter. Default: True.
73
+ add_bias_kv: add bias to the key and value sequences at dim=0.
74
+ add_zero_attn: add a new batch of zeros to the key and
75
+ value sequences at dim=1.
76
+ kdim: total number of features in key. Default: None.
77
+ vdim: total number of features in value. Default: None.
78
+ Note: if kdim and vdim are None, they will be set to embed_dim such that
79
+ query, key, and value have the same number of features.
80
+ Examples::
81
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
82
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
83
+ """
84
+ bias_k: Optional[torch.Tensor]
85
+ bias_v: Optional[torch.Tensor]
86
+
87
+ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
88
+ super(MultiheadAttention, self).__init__()
89
+ self.embed_dim = embed_dim
90
+ self.kdim = kdim if kdim is not None else embed_dim
91
+ self.vdim = vdim if vdim is not None else embed_dim
92
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
93
+
94
+ self.num_heads = num_heads
95
+ self.dropout = dropout
96
+ self.head_dim = embed_dim // num_heads
97
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
98
+
99
+ vdim = vdim if vdim is not None else embed_dim
100
+ self.out_proj = Linear(vdim, vdim)
101
+
102
+ self.in_proj_bias = None
103
+ self.in_proj_weight = None
104
+ self.bias_k = self.bias_v = None
105
+ self.q_proj_weight = None
106
+ self.k_proj_weight = None
107
+ self.v_proj_weight = None
108
+
109
+ self.add_zero_attn = add_zero_attn
110
+
111
+ self._reset_parameters()
112
+
113
+ def _reset_parameters(self):
114
+ constant_(self.out_proj.bias, 0.)
115
+
116
+ def __setstate__(self, state):
117
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
118
+ if '_qkv_same_embed_dim' not in state:
119
+ state['_qkv_same_embed_dim'] = True
120
+
121
+ super(MultiheadAttention, self).__setstate__(state)
122
+
123
+ def forward(self, query, key, value, key_padding_mask=None,
124
+ need_weights=True, attn_mask=None):
125
+ # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
126
+ r"""
127
+ Args:
128
+ query, key, value: map a query and a set of key-value pairs to an output.
129
+ See "Attention Is All You Need" for more details.
130
+ key_padding_mask: if provided, specified padding elements in the key will
131
+ be ignored by the attention. When given a binary mask and a value is True,
132
+ the corresponding value on the attention layer will be ignored. When given
133
+ a byte mask and a value is non-zero, the corresponding value on the attention
134
+ layer will be ignored
135
+ need_weights: output attn_output_weights.
136
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
137
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
138
+ Shape:
139
+ - Inputs:
140
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
141
+ the embedding dimension.
142
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
143
+ the embedding dimension.
144
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
145
+ the embedding dimension.
146
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
147
+ If a ByteTensor is provided, the non-zero positions will be ignored while the position
148
+ with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
149
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
150
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
151
+ 3D mask :math:`(N*\text{num_heads}, L, S)` where N is the batch size, L is the target sequence length,
152
+ S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
153
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
154
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
155
+ is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
156
+ is provided, it will be added to the attention weight.
157
+ - Outputs:
158
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
159
+ E is the embedding dimension.
160
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
161
+ L is the target sequence length, S is the source sequence length.
162
+ """
163
+ if not self._qkv_same_embed_dim:
164
+ return multi_head_attention_forward(
165
+ query, key, value, self.embed_dim, self.num_heads,
166
+ self.in_proj_weight, self.in_proj_bias,
167
+ self.bias_k, self.bias_v, self.add_zero_attn,
168
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
169
+ training=self.training,
170
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
171
+ attn_mask=attn_mask, use_separate_proj_weight=True,
172
+ q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
173
+ v_proj_weight=self.v_proj_weight, out_dim=self.vdim)
174
+ else:
175
+ return multi_head_attention_forward(
176
+ query, key, value, self.embed_dim, self.num_heads,
177
+ self.in_proj_weight, self.in_proj_bias,
178
+ self.bias_k, self.bias_v, self.add_zero_attn,
179
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
180
+ training=self.training,
181
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
182
+ attn_mask=attn_mask, out_dim=self.vdim)
183
+
184
+
185
+ def multi_head_attention_forward(query: Tensor,
186
+ key: Tensor,
187
+ value: Tensor,
188
+ embed_dim_to_check: int,
189
+ num_heads: int,
190
+ in_proj_weight: Tensor,
191
+ in_proj_bias: Tensor,
192
+ bias_k: Optional[Tensor],
193
+ bias_v: Optional[Tensor],
194
+ add_zero_attn: bool,
195
+ dropout_p: float,
196
+ out_proj_weight: Tensor,
197
+ out_proj_bias: Tensor,
198
+ training: bool = True,
199
+ key_padding_mask: Optional[Tensor] = None,
200
+ need_weights: bool = True,
201
+ attn_mask: Optional[Tensor] = None,
202
+ use_separate_proj_weight: bool = False,
203
+ q_proj_weight: Optional[Tensor] = None,
204
+ k_proj_weight: Optional[Tensor] = None,
205
+ v_proj_weight: Optional[Tensor] = None,
206
+ static_k: Optional[Tensor] = None,
207
+ static_v: Optional[Tensor] = None,
208
+ out_dim: Optional[Tensor] = None
209
+ ) -> Tuple[Tensor, Optional[Tensor]]:
210
+ r"""
211
+ Args:
212
+ query, key, value: map a query and a set of key-value pairs to an output.
213
+ See "Attention Is All You Need" for more details.
214
+ embed_dim_to_check: total dimension of the model.
215
+ num_heads: parallel attention heads.
216
+ in_proj_weight, in_proj_bias: input projection weight and bias.
217
+ bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
218
+ add_zero_attn: add a new batch of zeros to the key and
219
+ value sequences at dim=1.
220
+ dropout_p: probability of an element to be zeroed.
221
+ out_proj_weight, out_proj_bias: the output projection weight and bias.
222
+ training: apply dropout if is ``True``.
223
+ key_padding_mask: if provided, specified padding elements in the key will
224
+ be ignored by the attention. This is an binary mask. When the value is True,
225
+ the corresponding value on the attention layer will be filled with -inf.
226
+ need_weights: output attn_output_weights.
227
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
228
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
229
+ use_separate_proj_weight: the function accept the proj. weights for query, key,
230
+ and value in different forms. If false, in_proj_weight will be used, which is
231
+ a combination of q_proj_weight, k_proj_weight, v_proj_weight.
232
+ q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
233
+ static_k, static_v: static key and value used for attention operators.
234
+ Shape:
235
+ Inputs:
236
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
237
+ the embedding dimension.
238
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
239
+ the embedding dimension.
240
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
241
+ the embedding dimension.
242
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
243
+ If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
244
+ will be unchanged. If a BoolTensor is provided, the positions with the
245
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
246
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
247
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
248
+ S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
249
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
250
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
251
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
252
+ is provided, it will be added to the attention weight.
253
+ - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
254
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
255
+ - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
256
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
257
+ Outputs:
258
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
259
+ E is the embedding dimension.
260
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
261
+ L is the target sequence length, S is the source sequence length.
262
+ """
263
+ if not torch.jit.is_scripting():
264
+ tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v,
265
+ out_proj_weight, out_proj_bias)
266
+ if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
267
+ return handle_torch_function(
268
+ multi_head_attention_forward, tens_ops, query, key, value,
269
+ embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias,
270
+ bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight,
271
+ out_proj_bias, training=training, key_padding_mask=key_padding_mask,
272
+ need_weights=need_weights, attn_mask=attn_mask,
273
+ use_separate_proj_weight=use_separate_proj_weight,
274
+ q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight,
275
+ v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v)
276
+ tgt_len, bsz, embed_dim = query.size()
277
+ assert embed_dim == embed_dim_to_check
278
+ # allow MHA to have different sizes for the feature dimension
279
+ assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
280
+
281
+ head_dim = embed_dim // num_heads
282
+ v_head_dim = out_dim // num_heads
283
+ assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
284
+ scaling = float(head_dim) ** -0.5
285
+
286
+ q = query * scaling
287
+ k = key
288
+ v = value
289
+
290
+ if attn_mask is not None:
291
+ assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
292
+ attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
293
+ 'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
294
+ if attn_mask.dtype == torch.uint8:
295
+ warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
296
+ attn_mask = attn_mask.to(torch.bool)
297
+
298
+ if attn_mask.dim() == 2:
299
+ attn_mask = attn_mask.unsqueeze(0)
300
+ if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
301
+ raise RuntimeError('The size of the 2D attn_mask is not correct.')
302
+ elif attn_mask.dim() == 3:
303
+ if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
304
+ raise RuntimeError('The size of the 3D attn_mask is not correct.')
305
+ else:
306
+ raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
307
+ # attn_mask's dim is 3 now.
308
+
309
+ # convert ByteTensor key_padding_mask to bool
310
+ if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
311
+ warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
312
+ key_padding_mask = key_padding_mask.to(torch.bool)
313
+
314
+ if bias_k is not None and bias_v is not None:
315
+ if static_k is None and static_v is None:
316
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
317
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
318
+ if attn_mask is not None:
319
+ attn_mask = pad(attn_mask, (0, 1))
320
+ if key_padding_mask is not None:
321
+ key_padding_mask = pad(key_padding_mask, (0, 1))
322
+ else:
323
+ assert static_k is None, "bias cannot be added to static key."
324
+ assert static_v is None, "bias cannot be added to static value."
325
+ else:
326
+ assert bias_k is None
327
+ assert bias_v is None
328
+
329
+ q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
330
+ if k is not None:
331
+ k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
332
+ if v is not None:
333
+ v = v.contiguous().view(-1, bsz * num_heads, v_head_dim).transpose(0, 1)
334
+
335
+ if static_k is not None:
336
+ assert static_k.size(0) == bsz * num_heads
337
+ assert static_k.size(2) == head_dim
338
+ k = static_k
339
+
340
+ if static_v is not None:
341
+ assert static_v.size(0) == bsz * num_heads
342
+ assert static_v.size(2) == v_head_dim
343
+ v = static_v
344
+
345
+ src_len = k.size(1)
346
+
347
+ if key_padding_mask is not None:
348
+ assert key_padding_mask.size(0) == bsz
349
+ assert key_padding_mask.size(1) == src_len
350
+
351
+ if add_zero_attn:
352
+ src_len += 1
353
+ k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
354
+ v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
355
+ if attn_mask is not None:
356
+ attn_mask = pad(attn_mask, (0, 1))
357
+ if key_padding_mask is not None:
358
+ key_padding_mask = pad(key_padding_mask, (0, 1))
359
+
360
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
361
+ assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
362
+
363
+ if attn_mask is not None:
364
+ if attn_mask.dtype == torch.bool:
365
+ attn_output_weights.masked_fill_(attn_mask, float('-inf'))
366
+ else:
367
+ attn_output_weights += attn_mask
368
+
369
+
370
+ if key_padding_mask is not None:
371
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
372
+ attn_output_weights = attn_output_weights.masked_fill(
373
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
374
+ float('-inf'),
375
+ )
376
+ attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
377
+
378
+ # attn_output_weights = softmax(
379
+ # attn_output_weights, dim=-1)
380
+ attn_output_weights = softmax(
381
+ attn_output_weights - attn_output_weights.max(dim=-1, keepdim=True)[0], dim=-1)
382
+ attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)
383
+
384
+ attn_output = torch.bmm(attn_output_weights, v)
385
+ assert list(attn_output.size()) == [bsz * num_heads, tgt_len, v_head_dim]
386
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, out_dim)
387
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
388
+
389
+ if need_weights:
390
+ # average attention weights over heads
391
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
392
+ return attn_output, attn_output_weights.sum(dim=1) / num_heads
393
+ else:
394
+ return attn_output, None
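Because the q/k/v input projections have been stripped, this MultiheadAttention expects pre-projected query, key, and value tensors; only the scaling, the attention itself (with the row-wise max subtracted before the softmax purely for numerical stability), and out_proj remain. A minimal usage sketch with hypothetical sizes:

    import torch

    attn = MultiheadAttention(embed_dim=256, num_heads=8)
    q = torch.rand(10, 2, 256)   # (L, N, E): target length 10, batch size 2
    kv = torch.rand(20, 2, 256)  # (S, N, E): source length 20
    out, weights = attn(q, kv, kv)
    print(out.shape, weights.shape)  # (10, 2, 256) and (2, 10, 20) head-averaged weights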
third_party/cgdetr/cg_detr/config.py ADDED
@@ -0,0 +1,261 @@
1
+ import os
2
+ import time
3
+ import torch
4
+ import argparse
5
+
6
+ from third_party.cgdetr.utils.basic_utils import mkdirp, load_json, save_json, make_zipfile, dict_to_markdown
7
+ import shutil
8
+
9
+ class BaseOptions(object):
10
+ saved_option_filename = "opt.json"
11
+ ckpt_filename = "model.ckpt"
12
+ tensorboard_log_dir = "tensorboard_log"
13
+ train_log_filename = "train.log.txt"
14
+ eval_log_filename = "eval.log.txt"
15
+
16
+ def __init__(self):
17
+ self.parser = None
18
+ self.initialized = False
19
+ self.opt = None
20
+
21
+ def initialize(self):
22
+ self.initialized = True
23
+ parser = argparse.ArgumentParser()
24
+ # parser.add_argument("--dset_name", type=str, choices=["hl", 'charadesSTA', ])
25
+ # parser.add_argument("--dset_domain", type=str,
26
+ # help="Domain to train for tvsum dataset. (Only used for tvsum and youtube-hl)")
27
+
28
+ parser.add_argument("--eval_split_name", type=str, default="val",
29
+ help="should match keys in video_duration_idx_path, must set for VCMR")
30
+ parser.add_argument("--debug", action="store_true",
31
+ help="debug (fast) mode, break all loops, do not load all data into memory.")
32
+ parser.add_argument("--data_ratio", type=float, default=1.0,
33
+ help="how many training and eval data to use. 1.0: use all, 0.1: use 10%."
34
+ "Use small portion for debug purposes. Note this is different from --debug, "
35
+ "which works by breaking the loops, typically they are not used together.")
36
+ parser.add_argument("--results_root", type=str, default="results")
37
+ parser.add_argument("--exp_id", type=str, default=None, help="id of this run, required at training")
38
+ parser.add_argument("--seed", type=int, default=2018, help="random seed")
39
+ # parser.add_argument("--device", type=int, default=0, help="0 cuda, -1 cpu")
40
+ parser.add_argument("--num_workers", type=int, default=0,
41
+ help="num subprocesses used to load the data, 0: use main process")
42
+ parser.add_argument("--no_pin_memory", action="store_true",
43
+ help="Don't use pin_memory=True for dataloader. "
44
+ "ref: https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/4")
45
+
46
+ # training config
47
+ # parser.add_argument("--lr", type=float, default=2e-4, help="learning rate")
48
+ # parser.add_argument("--lr_drop", type=int, default=800, help="drop learning rate to 1/10 every lr_drop epochs")
49
+ # parser.add_argument("--wd", type=float, default=1e-4, help="weight decay")
50
+ parser.add_argument("--n_epoch", type=int, default=200, help="number of epochs to run")
51
+ parser.add_argument("--max_es_cnt", type=int, default=200,
52
+ help="number of epochs to early stop, use -1 to disable early stop")
53
+ # parser.add_argument("--bsz", type=int, default=32, help="mini-batch size")
54
+ # parser.add_argument("--eval_bsz", type=int, default=100,
55
+ # help="mini-batch size at inference, for query")
56
+ parser.add_argument("--eval_epoch", type=int, default=5,help="inference epoch")
57
+ parser.add_argument("--grad_clip", type=float, default=0.1, help="perform gradient clip, -1: disable")
58
+ parser.add_argument("--eval_untrained", action="store_true", help="Evaluate on un-trained model")
59
+ parser.add_argument("--resume", type=str, default=None,
60
+ help="checkpoint path to resume or evaluate, without --resume_all this only load weights")
61
+ parser.add_argument("--resume_all", action="store_true",
62
+ help="if --resume_all, load optimizer/scheduler/epoch as well")
63
+ parser.add_argument("--start_epoch", type=int, default=None,
64
+ help="if None, will be set automatically when using --resume_all")
65
+
66
+ # Data config
67
+ parser.add_argument("--max_q_l", type=int, default=-1)
68
+ parser.add_argument("--max_v_l", type=int, default=-1)
69
+ parser.add_argument("--clip_length", type=float, default=2)
70
+ parser.add_argument("--max_windows", type=int, default=5)
71
+
72
+ parser.add_argument("--train_path", type=str, default=None)
73
+ parser.add_argument("--eval_path", type=str, default=None,
74
+ help="Evaluating during training, for Dev set. If None, will only do training, ")
75
+ parser.add_argument("--no_norm_vfeat", action="store_true", help="Do not do normalize video feat")
76
+ parser.add_argument("--no_norm_tfeat", action="store_true", help="Do not do normalize text feat")
77
+ parser.add_argument("--v_feat_dirs", type=str, nargs="+",
78
+ help="video feature dirs. If more than one, will concat their features. "
79
+ "Note that sub ctx features are also accepted here.")
80
+ parser.add_argument("--t_feat_dir", type=str, help="text/query feature dir")
81
+ # parser.add_argument("--a_feat_dir", type=str, help="audio feature dir")
82
+ parser.add_argument("--v_feat_dim", type=int, default=770, help="video feature dim")
83
+ parser.add_argument("--t_feat_dim", type=int, default=4096, help="text/query feature dim")
84
+ # parser.add_argument("--a_feat_dim", type=int, help="audio feature dim")
85
+ parser.add_argument("--ctx_mode", type=str, default="video_tef")
86
+
87
+ # Model config
88
+ parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
89
+ help="Type of positional embedding to use on top of the image features")
90
+ # * Transformer
91
+ parser.add_argument('--enc_layers', default=3, type=int,
92
+ help="Number of encoding layers in the transformer")
93
+ parser.add_argument('--dec_layers', default=3, type=int,
94
+ help="Number of decoding layers in the transformer")
95
+ parser.add_argument('--t2v_layers', default=2, type=int,
96
+ help="Number of decoding layers in the transformer")
97
+ parser.add_argument('--sent_layers', default=1, type=int,
98
+ help="Number of decoding layers in the transformer")
99
+ parser.add_argument('--moment_layers', default=1, type=int,
100
+ help="Number of decoding layers in the transformer")
101
+ parser.add_argument('--dummy_layers', default=2, type=int,
102
+ help="Number of encoding layers in the transformer")
103
+ parser.add_argument('--dim_feedforward', default=1024, type=int,
104
+ help="Intermediate size of the feedforward layers in the transformer blocks")
105
+ parser.add_argument('--hidden_dim', default=256, type=int,
106
+ help="Size of the embeddings (dimension of the transformer)")
107
+ parser.add_argument('--input_dropout', default=0.5, type=float,
108
+ help="Dropout applied in input")
109
+ parser.add_argument('--dropout', default=0.1, type=float,
110
+ help="Dropout applied in the transformer")
111
+ parser.add_argument("--txt_drop_ratio", default=0, type=float,
112
+ help="drop txt_drop_ratio tokens from text input. 0.1=10%")
113
+ parser.add_argument("--use_txt_pos", action="store_true", help="use position_embedding for text as well.")
114
+ parser.add_argument('--nheads', default=8, type=int,
115
+ help="Number of attention heads inside the transformer's attentions")
116
+ parser.add_argument('--num_queries', default=10, type=int,
117
+ help="Number of query slots")
118
+ parser.add_argument('--num_dummies', default=45, type=int,
119
+ help="Number of dummy tokens")
120
+ parser.add_argument('--total_prompts', default=10, type=int,
121
+ help="Number of query slots")
122
+ parser.add_argument('--num_prompts', default=1, type=int,
123
+ help="Number of dummy tokens")
124
+ parser.add_argument('--pre_norm', action='store_true')
125
+ # other model configs
126
+ parser.add_argument("--n_input_proj", type=int, default=2, help="#layers to encoder input")
127
+ parser.add_argument("--contrastive_hdim", type=int, default=64, help="dim for contrastive embeddings")
128
+ parser.add_argument("--temperature", type=float, default=0.07, help="temperature nce contrastive_align_loss")
129
+ # Loss
130
+
131
+ parser.add_argument("--saliency_margin", type=float, default=0.2)
132
+ parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
133
+ help="Disables auxiliary decoding losses (loss at each layer)")
134
+ parser.add_argument("--span_loss_type", default="l1", type=str, choices=['l1', 'ce'],
135
+ help="l1: (center-x, width) regression. ce: (st_idx, ed_idx) classification.")
136
+ parser.add_argument("--contrastive_align_loss", action="store_true",
137
+ help="Disable contrastive_align_loss between matched query spans and the text.")
138
+ # * Matcher
139
+ parser.add_argument('--set_cost_span', default=10, type=float,
140
+ help="L1 span coefficient in the matching cost")
141
+ parser.add_argument('--set_cost_giou', default=1, type=float,
142
+ help="giou span coefficient in the matching cost")
143
+ parser.add_argument('--set_cost_class', default=4, type=float,
144
+ help="Class coefficient in the matching cost")
145
+
146
+ # * Loss coefficients
147
+ parser.add_argument("--lw_saliency", type=float, default=1.,
148
+ help="weight for saliency loss, set to 0 will ignore")
149
+ parser.add_argument("--lw_wattn", type=float, default=1.,
150
+ help="weight for saliency loss, set to 0 will ignore")
151
+ parser.add_argument("--lw_ms_align", type=float, default=1.,
152
+ help="weight for saliency loss, set to 0 will ignore")
153
+ parser.add_argument("--lw_distill", type=float, default=1.,
154
+ help="weight for saliency loss, set to 0 will ignore")
155
+ parser.add_argument('--span_loss_coef', default=10, type=float)
156
+ parser.add_argument('--giou_loss_coef', default=1, type=float)
157
+ parser.add_argument('--label_loss_coef', default=4, type=float)
158
+ parser.add_argument('--eos_coef', default=0.1, type=float,
159
+ help="Relative classification weight of the no-object class")
160
+ parser.add_argument("--contrastive_align_loss_coef", default=0.0, type=float)
161
+
162
+ parser.add_argument("--no_sort_results", action="store_true",
163
+ help="do not sort results, use this for moment query visualization")
164
+ parser.add_argument("--max_before_nms", type=int, default=10)
165
+ parser.add_argument("--max_after_nms", type=int, default=10)
166
+ parser.add_argument("--conf_thd", type=float, default=0.0, help="only keep windows with conf >= conf_thd")
167
+ parser.add_argument("--nms_thd", type=float, default=-1,
168
+ help="additionally use non-maximum suppression "
169
+ "(or non-minimum suppression for distance)"
170
+ "to post-processing the predictions. "
171
+ "-1: do not use nms. [0, 1]")
172
+ self.parser = parser
173
+
174
+ def display_save(self, opt):
175
+ args = vars(opt)
176
+ # Display settings
177
+ print(dict_to_markdown(vars(opt), max_str_len=120))
178
+ # Save settings
179
+ if not isinstance(self, TestOptions):
180
+ option_file_path = os.path.join(opt.results_dir, self.saved_option_filename) # not yaml file indeed
181
+ save_json(args, option_file_path, save_pretty=True)
182
+
183
+ def parse(self, a_feat_dir=None):
184
+ if not self.initialized:
185
+ self.initialize()
186
+ opt = self.parser.parse_args()
187
+
188
+ if opt.debug:
189
+ opt.results_root = os.path.sep.join(opt.results_root.split(os.path.sep)[:-1] + ["debug_results", ])
190
+ opt.num_workers = 0
191
+
192
+ if isinstance(self, TestOptions):
193
+ # modify model_dir to absolute path
194
+ # opt.model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results", opt.model_dir)
195
+ opt.model_dir = os.path.dirname(opt.resume)
196
+ if a_feat_dir is not None:
197
+ opt.a_feat_dir = a_feat_dir
198
+ saved_options = load_json(os.path.join(opt.model_dir, self.saved_option_filename))
199
+ for arg in saved_options: # use saved options to overwrite all BaseOptions args.
200
+ if arg not in ["results_root", "num_workers", "nms_thd", "debug", # "max_before_nms", "max_after_nms"
201
+ "max_pred_l", "min_pred_l",
202
+ "resume", "resume_all", "no_sort_results"]:
203
+ setattr(opt, arg, saved_options[arg])
204
+ # opt.no_core_driver = True
205
+ if opt.eval_results_dir is not None:
206
+ opt.results_dir = opt.eval_results_dir
207
+ else:
208
+ if opt.exp_id is None:
209
+ raise ValueError("--exp_id is required for at a training option!")
210
+
211
+ ctx_str = opt.ctx_mode + "_sub" if any(["sub_ctx" in p for p in opt.v_feat_dirs]) else opt.ctx_mode
212
+ opt.results_dir = os.path.join(opt.results_root,
213
+ "-".join([opt.dset_name, ctx_str, opt.exp_id,
214
+ str(opt.enc_layers) + str(opt.dec_layers) + str(opt.t2v_layers) + str(opt.moment_layers) + str(opt.dummy_layers) + str(opt.sent_layers),
215
+ 'ndum_' + str(opt.num_dummies), 'nprom_' + str(opt.num_prompts) + '_' + str(opt.total_prompts)]))
216
+ mkdirp(opt.results_dir)
217
+ save_fns = ['cg_detr/model.py', 'cg_detr/transformer.py']
218
+ for save_fn in save_fns:
219
+ shutil.copyfile(save_fn, os.path.join(opt.results_dir, os.path.basename(save_fn)))
220
+
221
+ # save a copy of current code
222
+ code_dir = os.path.dirname(os.path.realpath(__file__))
223
+ code_zip_filename = os.path.join(opt.results_dir, "code.zip")
224
+ make_zipfile(code_dir, code_zip_filename,
225
+ enclosing_dir="code",
226
+ exclude_dirs_substring="results",
227
+ exclude_dirs=["results", "debug_results", "__pycache__"],
228
+ exclude_extensions=[".pyc", ".ipynb", ".swap"], )
229
+
230
+ self.display_save(opt)
231
+
232
+ opt.ckpt_filepath = os.path.join(opt.results_dir, self.ckpt_filename)
233
+ opt.train_log_filepath = os.path.join(opt.results_dir, self.train_log_filename)
234
+ opt.eval_log_filepath = os.path.join(opt.results_dir, self.eval_log_filename)
235
+ opt.tensorboard_log_dir = os.path.join(opt.results_dir, self.tensorboard_log_dir)
236
+ opt.device = torch.device("cuda" if opt.device >= 0 else "cpu")
237
+ opt.pin_memory = not opt.no_pin_memory
238
+
239
+ opt.use_tef = "tef" in opt.ctx_mode
240
+ opt.use_video = "video" in opt.ctx_mode
241
+ if not opt.use_video:
242
+ opt.v_feat_dim = 0
243
+ if opt.use_tef:
244
+ opt.v_feat_dim += 2
245
+
246
+ self.opt = opt
247
+ return opt
248
+
249
+
250
+ class TestOptions(BaseOptions):
251
+ """add additional options for evaluating"""
252
+
253
+ def initialize(self):
254
+ BaseOptions.initialize(self)
255
+ # also need to specify --eval_split_name
256
+ self.parser.add_argument("--eval_id", type=str, help="evaluation id")
257
+ self.parser.add_argument("--eval_results_dir", type=str, default=None,
258
+ help="dir to save results, if not set, fall back to training results_dir")
259
+ self.parser.add_argument("--model_dir", type=str,
260
+ help="dir contains the model file, will be converted to absolute path afterwards")
261
+
third_party/cgdetr/cg_detr/crossattention.py ADDED
@@ -0,0 +1,396 @@
1
+ # ------------------------------------------------------------------------
2
+ # DAB-DETR
3
+ # Copyright (c) 2022 IDEA. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
7
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
8
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
9
+ # ------------------------------------------------------------------------
10
+ # Modified from DETR (https://github.com/facebookresearch/detr)
11
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
12
+ # ------------------------------------------------------------------------
13
+ # Modified from codes in torch.nn
14
+ # ------------------------------------------------------------------------
15
+
16
+ """
17
+ MultiheadAttention that supports query, key, and value with different dimensions.
18
+ Query, key, and value projections are removed.
19
+ Mostly copy-paste from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/activation.py#L873
20
+ and https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L4837
21
+ """
22
+
23
+ import copy
24
+ from typing import Optional, List
25
+
26
+ import torch
27
+ import torch.nn.functional as F
28
+ from torch import nn, Tensor
29
+
30
+ import warnings
31
+ from typing import Tuple, Optional
32
+
33
+ import torch
34
+ from torch import Tensor
35
+ from torch.nn.modules.linear import Linear
36
+ from torch.nn.init import xavier_uniform_
37
+ from torch.nn.init import constant_
38
+ from torch.nn.init import xavier_normal_
39
+ from torch.nn.parameter import Parameter
40
+ from torch.nn.modules.module import Module
41
+ from torch.nn import functional as F
42
+
43
+ import warnings
44
+ import math
45
+
46
+ from torch._C import _infer_size, _add_docstr
47
+ from torch.nn import _reduction as _Reduction
48
+ from torch.nn.modules import utils
49
+ from torch.nn.modules.utils import _single, _pair, _triple, _list_with_default
50
+ from torch.nn import grad
51
+ from torch import _VF
52
+ from torch._jit_internal import boolean_dispatch, List, Optional, _overload, Tuple
53
+ try:
54
+ from torch.overrides import has_torch_function, handle_torch_function
55
+ except:
56
+ from torch._overrides import has_torch_function, handle_torch_function
57
+ Tensor = torch.Tensor
58
+
59
+ from torch.nn.functional import linear, pad, softmax, dropout
60
+
61
+ class MultiheadAttention(Module):
62
+ r"""Allows the model to jointly attend to information
63
+ from different representation subspaces.
64
+ See reference: Attention Is All You Need
65
+ .. math::
66
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
67
+ \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
68
+ Args:
69
+ embed_dim: total dimension of the model.
70
+ num_heads: parallel attention heads.
71
+ dropout: a Dropout layer on attn_output_weights. Default: 0.0.
72
+ bias: add bias as module parameter. Default: True.
73
+ add_bias_kv: add bias to the key and value sequences at dim=0.
74
+ add_zero_attn: add a new batch of zeros to the key and
75
+ value sequences at dim=1.
76
+ kdim: total number of features in key. Default: None.
77
+ vdim: total number of features in value. Default: None.
78
+ Note: if kdim and vdim are None, they will be set to embed_dim such that
79
+ query, key, and value have the same number of features.
80
+ Examples::
81
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
82
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
83
+ """
84
+ bias_k: Optional[torch.Tensor]
85
+ bias_v: Optional[torch.Tensor]
86
+
87
+ def __init__(self, embed_dim, num_heads, dropout=0., num_dummies=3, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
88
+ super(MultiheadAttention, self).__init__()
89
+ self.num_dummies = num_dummies
90
+ self.embed_dim = embed_dim
91
+ self.kdim = kdim if kdim is not None else embed_dim
92
+ self.vdim = vdim if vdim is not None else embed_dim
93
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
94
+
95
+ self.num_heads = num_heads
96
+ self.dropout = dropout
97
+ self.head_dim = embed_dim // num_heads
98
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
99
+
100
+ vdim = vdim if vdim is not None else embed_dim
101
+ self.out_proj = Linear(vdim, vdim)
102
+
103
+ self.in_proj_bias = None
104
+ self.in_proj_weight = None
105
+ self.bias_k = self.bias_v = None
106
+ self.q_proj_weight = None
107
+ self.k_proj_weight = None
108
+ self.v_proj_weight = None
109
+
110
+ self.add_zero_attn = add_zero_attn
111
+
112
+ self._reset_parameters()
113
+
114
+ def _reset_parameters(self):
115
+ constant_(self.out_proj.bias, 0.)
116
+
117
+ def __setstate__(self, state):
118
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
119
+ if '_qkv_same_embed_dim' not in state:
120
+ state['_qkv_same_embed_dim'] = True
121
+
122
+ super(MultiheadAttention, self).__setstate__(state)
123
+
124
+ def forward(self, query, key, value, key_padding_mask=None,
125
+ need_weights=True, attn_mask=None, dummy=True):
126
+ # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
127
+ r"""
128
+ Args:
129
+ query, key, value: map a query and a set of key-value pairs to an output.
130
+ See "Attention Is All You Need" for more details.
131
+ key_padding_mask: if provided, specified padding elements in the key will
132
+ be ignored by the attention. When given a binary mask and a value is True,
133
+ the corresponding value on the attention layer will be ignored. When given
134
+ a byte mask and a value is non-zero, the corresponding value on the attention
135
+ layer will be ignored
136
+ need_weights: output attn_output_weights.
137
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
138
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
139
+ Shape:
140
+ - Inputs:
141
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
142
+ the embedding dimension.
143
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
144
+ the embedding dimension.
145
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
146
+ the embedding dimension.
147
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
148
+ If a ByteTensor is provided, the non-zero positions will be ignored while the position
149
+ with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
150
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
151
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
152
+ 3D mask :math:`(N*\text{num_heads}, L, S)` where N is the batch size, L is the target sequence length,
153
+ S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
154
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
155
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
156
+ is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
157
+ is provided, it will be added to the attention weight.
158
+ - Outputs:
159
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
160
+ E is the embedding dimension.
161
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
162
+ L is the target sequence length, S is the source sequence length.
163
+ """
164
+ if not self._qkv_same_embed_dim:
165
+ return multi_head_attention_forward(
166
+ query, key, value, self.embed_dim, self.num_heads,
167
+ self.in_proj_weight, self.in_proj_bias,
168
+ self.bias_k, self.bias_v, self.add_zero_attn,
169
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
170
+ training=self.training,
171
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
172
+ attn_mask=attn_mask, use_separate_proj_weight=True,
173
+ q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
174
+ v_proj_weight=self.v_proj_weight, out_dim=self.vdim, num_dummies=self.num_dummies, dummy=dummy)
175
+ else:
176
+ return multi_head_attention_forward(
177
+ query, key, value, self.embed_dim, self.num_heads,
178
+ self.in_proj_weight, self.in_proj_bias,
179
+ self.bias_k, self.bias_v, self.add_zero_attn,
180
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
181
+ training=self.training,
182
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
183
+ attn_mask=attn_mask, out_dim=self.vdim, num_dummies=self.num_dummies, dummy=dummy)
184
+
185
+
186
+ def multi_head_attention_forward(query: Tensor,
187
+ key: Tensor,
188
+ value: Tensor,
189
+ embed_dim_to_check: int,
190
+ num_heads: int,
191
+ in_proj_weight: Tensor,
192
+ in_proj_bias: Tensor,
193
+ bias_k: Optional[Tensor],
194
+ bias_v: Optional[Tensor],
195
+ add_zero_attn: bool,
196
+ dropout_p: float,
197
+ out_proj_weight: Tensor,
198
+ out_proj_bias: Tensor,
199
+ training: bool = True,
200
+ key_padding_mask: Optional[Tensor] = None,
201
+ need_weights: bool = True,
202
+ attn_mask: Optional[Tensor] = None,
203
+ use_separate_proj_weight: bool = False,
204
+ q_proj_weight: Optional[Tensor] = None,
205
+ k_proj_weight: Optional[Tensor] = None,
206
+ v_proj_weight: Optional[Tensor] = None,
207
+ static_k: Optional[Tensor] = None,
208
+ static_v: Optional[Tensor] = None,
209
+ out_dim: Optional[Tensor] = None,
210
+ num_dummies=3,
211
+ dummy=True,
212
+ ) -> Tuple[Tensor, Optional[Tensor]]:
213
+ r"""
214
+ Args:
215
+ query, key, value: map a query and a set of key-value pairs to an output.
216
+ See "Attention Is All You Need" for more details.
217
+ embed_dim_to_check: total dimension of the model.
218
+ num_heads: parallel attention heads.
219
+ in_proj_weight, in_proj_bias: input projection weight and bias.
220
+ bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
221
+ add_zero_attn: add a new batch of zeros to the key and
222
+ value sequences at dim=1.
223
+ dropout_p: probability of an element to be zeroed.
224
+ out_proj_weight, out_proj_bias: the output projection weight and bias.
225
+ training: apply dropout if is ``True``.
226
+ key_padding_mask: if provided, specified padding elements in the key will
227
+ be ignored by the attention. This is a binary mask. When the value is True,
228
+ the corresponding value on the attention layer will be filled with -inf.
229
+ need_weights: output attn_output_weights.
230
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
231
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
232
+ use_separate_proj_weight: the function accepts the proj. weights for query, key,
233
+ and value in different forms. If false, in_proj_weight will be used, which is
234
+ a combination of q_proj_weight, k_proj_weight, v_proj_weight.
235
+ q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
236
+ static_k, static_v: static key and value used for attention operators.
237
+ Shape:
238
+ Inputs:
239
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
240
+ the embedding dimension.
241
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
242
+ the embedding dimension.
243
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
244
+ the embedding dimension.
245
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
246
+ If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
247
+ will be unchanged. If a BoolTensor is provided, the positions with the
248
+ value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.
249
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
250
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
251
+ S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
252
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
253
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
254
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
255
+ is provided, it will be added to the attention weight.
256
+ - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
257
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
258
+ - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
259
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
260
+ Outputs:
261
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
262
+ E is the embedding dimension.
263
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
264
+ L is the target sequence length, S is the source sequence length.
265
+ """
266
+ if not torch.jit.is_scripting():
267
+ tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v,
268
+ out_proj_weight, out_proj_bias)
269
+ if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
270
+ return handle_torch_function(
271
+ multi_head_attention_forward, tens_ops, query, key, value,
272
+ embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias,
273
+ bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight,
274
+ out_proj_bias, training=training, key_padding_mask=key_padding_mask,
275
+ need_weights=need_weights, attn_mask=attn_mask,
276
+ use_separate_proj_weight=use_separate_proj_weight,
277
+ q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight,
278
+ v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v)
279
+ tgt_len, bsz, embed_dim = query.size()
280
+ assert embed_dim == embed_dim_to_check
281
+ # allow MHA to have different sizes for the feature dimension
282
+ assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
283
+
284
+ head_dim = embed_dim // num_heads
285
+ v_head_dim = out_dim // num_heads
286
+ assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
287
+ scaling = float(head_dim) ** -0.5
288
+
289
+ q = query * scaling
290
+ k = key
291
+ v = value
292
+
293
+ if attn_mask is not None:
294
+ assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
295
+ attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
296
+ 'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
297
+ if attn_mask.dtype == torch.uint8:
298
+ warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
299
+ attn_mask = attn_mask.to(torch.bool)
300
+
301
+ if attn_mask.dim() == 2:
302
+ attn_mask = attn_mask.unsqueeze(0)
303
+ if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
304
+ raise RuntimeError('The size of the 2D attn_mask is not correct.')
305
+ elif attn_mask.dim() == 3:
306
+ if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
307
+ raise RuntimeError('The size of the 3D attn_mask is not correct.')
308
+ else:
309
+ raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
310
+ # attn_mask's dim is 3 now.
311
+
312
+ # convert ByteTensor key_padding_mask to bool
313
+ if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
314
+ warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
315
+ key_padding_mask = key_padding_mask.to(torch.bool)
316
+
317
+ if bias_k is not None and bias_v is not None:
318
+ if static_k is None and static_v is None:
319
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
320
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
321
+ if attn_mask is not None:
322
+ attn_mask = pad(attn_mask, (0, 1))
323
+ if key_padding_mask is not None:
324
+ key_padding_mask = pad(key_padding_mask, (0, 1))
325
+ else:
326
+ assert static_k is None, "bias cannot be added to static key."
327
+ assert static_v is None, "bias cannot be added to static value."
328
+ else:
329
+ assert bias_k is None
330
+ assert bias_v is None
331
+
332
+ q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
333
+ if k is not None:
334
+ k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
335
+ if v is not None:
336
+ v = v.contiguous().view(-1, bsz * num_heads, v_head_dim).transpose(0, 1)
337
+
338
+ if static_k is not None:
339
+ assert static_k.size(0) == bsz * num_heads
340
+ assert static_k.size(2) == head_dim
341
+ k = static_k
342
+
343
+ if static_v is not None:
344
+ assert static_v.size(0) == bsz * num_heads
345
+ assert static_v.size(2) == v_head_dim
346
+ v = static_v
347
+
348
+ src_len = k.size(1)
349
+
350
+ if key_padding_mask is not None:
351
+ assert key_padding_mask.size(0) == bsz
352
+ assert key_padding_mask.size(1) == src_len
353
+
354
+ if add_zero_attn:
355
+ src_len += 1
356
+ k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
357
+ v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
358
+ if attn_mask is not None:
359
+ attn_mask = pad(attn_mask, (0, 1))
360
+ if key_padding_mask is not None:
361
+ key_padding_mask = pad(key_padding_mask, (0, 1))
362
+
363
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
364
+ assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
365
+
366
+ if attn_mask is not None:
367
+ if attn_mask.dtype == torch.bool:
368
+ attn_output_weights.masked_fill_(attn_mask, float('-inf'))
369
+ else:
370
+ attn_output_weights += attn_mask
371
+
372
+
373
+ if key_padding_mask is not None:
374
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
375
+ attn_output_weights = attn_output_weights.masked_fill(
376
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
377
+ float('-inf'),
378
+ )
379
+ attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
380
+
381
+ attn_output_weights = softmax(attn_output_weights, dim=-1)
382
+ attn_output_weights_d = dropout(attn_output_weights, p=dropout_p, training=training)
383
+ if dummy:
384
+ attn_output = torch.bmm(attn_output_weights_d[:, :, num_dummies:], v[:, num_dummies:,:])
385
+ else:
386
+ attn_output = torch.bmm(attn_output_weights_d, v)
387
+ assert list(attn_output.size()) == [bsz * num_heads, tgt_len, v_head_dim]
388
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, out_dim)
389
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
390
+
391
+ if need_weights:
392
+ # average attention weights over heads
393
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
394
+ return attn_output, attn_output_weights.sum(dim=1) / num_heads
395
+ else:
396
+ return attn_output, None
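A minimal sketch of the dummy handling at the end of multi_head_attention_forward above: all keys, including the num_dummies dummy entries placed first, take part in the softmax, but with dummy=True the first num_dummies columns of the attention weights and rows of the values are dropped before the weighted sum, so the dummies only absorb attention mass. Shapes follow the function above; the tensors here are random placeholders.

import torch

bsz_heads, tgt_len, src_len, head_dim, num_dummies = 2, 4, 7, 8, 3
attn_output_weights = torch.softmax(torch.randn(bsz_heads, tgt_len, src_len), dim=-1)
v = torch.randn(bsz_heads, src_len, head_dim)

# dummy=True: dummy keys shared the softmax but are excluded from the output
out_dummy = torch.bmm(attn_output_weights[:, :, num_dummies:], v[:, num_dummies:, :])
# dummy=False: standard attention output over all keys
out_plain = torch.bmm(attn_output_weights, v)
print(out_dummy.shape, out_plain.shape)  # both torch.Size([2, 4, 8])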
third_party/cgdetr/cg_detr/inference.py ADDED
@@ -0,0 +1,480 @@
1
+ import pprint
2
+ from tqdm import tqdm, trange
3
+ import numpy as np
4
+ import os
5
+ from collections import OrderedDict, defaultdict
6
+ from utils.basic_utils import AverageMeter
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torch.backends.cudnn as cudnn
11
+ from torch.utils.data import DataLoader
12
+
13
+ from cg_detr.config import TestOptions
14
+ from cg_detr.model import build_model
15
+ from cg_detr.span_utils import span_cxw_to_xx
16
+ from cg_detr.start_end_dataset import StartEndDataset, start_end_collate, prepare_batch_inputs
17
+ from cg_detr.postprocessing_cg_detr import PostProcessorDETR
18
+ from standalone_eval.eval import eval_submission
19
+ from utils.basic_utils import save_jsonl, save_json
20
+ from utils.temporal_nms import temporal_nms
21
+
22
+ import logging
23
+
24
+ logger = logging.getLogger(__name__)
25
+ logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
26
+ datefmt="%Y-%m-%d %H:%M:%S",
27
+ level=logging.INFO)
28
+
29
+
30
+ def post_processing_mr_nms(mr_res, nms_thd, max_before_nms, max_after_nms):
31
+ mr_res_after_nms = []
32
+ for e in mr_res:
33
+ e["pred_relevant_windows"] = temporal_nms(
34
+ e["pred_relevant_windows"][:max_before_nms],
35
+ nms_thd=nms_thd,
36
+ max_after_nms=max_after_nms
37
+ )
38
+ mr_res_after_nms.append(e)
39
+ return mr_res_after_nms
40
+
41
+
42
+ def eval_epoch_post_processing(submission, opt, gt_data, save_submission_filename):
43
+ # IOU_THDS = (0.5, 0.7)
44
+ logger.info("Saving/Evaluating before nms results")
45
+ submission_path = os.path.join(opt.results_dir, save_submission_filename)
46
+ save_jsonl(submission, submission_path)
47
+
48
+ if opt.eval_split_name in ["val"]: # since test_public has no GT
49
+ metrics = eval_submission(
50
+ submission, gt_data,
51
+ verbose=opt.debug, match_number=not opt.debug
52
+ )
53
+ save_metrics_path = submission_path.replace(".jsonl", "_metrics.json")
54
+ save_json(metrics, save_metrics_path, save_pretty=True, sort_keys=False)
55
+ latest_file_paths = [submission_path, save_metrics_path]
56
+ else:
57
+ metrics = None
58
+ latest_file_paths = [submission_path, ]
59
+
60
+ if opt.nms_thd != -1:
61
+ logger.info("[MR] Performing nms with nms_thd {}".format(opt.nms_thd))
62
+ submission_after_nms = post_processing_mr_nms(
63
+ submission, nms_thd=opt.nms_thd,
64
+ max_before_nms=opt.max_before_nms, max_after_nms=opt.max_after_nms
65
+ )
66
+
67
+ logger.info("Saving/Evaluating nms results")
68
+ submission_nms_path = submission_path.replace(".jsonl", "_nms_thd_{}.jsonl".format(opt.nms_thd))
69
+ save_jsonl(submission_after_nms, submission_nms_path)
70
+ if opt.eval_split_name == "val":
71
+ metrics_nms = eval_submission(
72
+ submission_after_nms, gt_data,
73
+ verbose=opt.debug, match_number=not opt.debug
74
+ )
75
+ save_metrics_nms_path = submission_nms_path.replace(".jsonl", "_metrics.json")
76
+ save_json(metrics_nms, save_metrics_nms_path, save_pretty=True, sort_keys=False)
77
+ latest_file_paths += [submission_nms_path, save_metrics_nms_path]
78
+ else:
79
+ metrics_nms = None
80
+ latest_file_paths = [submission_nms_path, ]
81
+ else:
82
+ metrics_nms = None
83
+ return metrics, metrics_nms, latest_file_paths
84
+
85
+
86
+ # for HL
87
+ @torch.no_grad()
88
+ def compute_hl_results(model, eval_loader, opt, epoch_i=None, criterion=None, tb_writer=None):
89
+ model.eval()
90
+ if criterion:
91
+ assert eval_loader.dataset.load_labels
92
+ criterion.eval()
93
+
94
+ loss_meters = defaultdict(AverageMeter)
95
+ write_tb = tb_writer is not None and epoch_i is not None
96
+
97
+ mr_res = []
98
+
99
+ topk = 5 # top-5 map
100
+
101
+ video_ap_collected = []
102
+ for batch in tqdm(eval_loader, desc="compute st ed scores"):
103
+ query_meta = batch[0]
104
+
105
+ model_inputs, targets = prepare_batch_inputs(batch[1], opt.device, non_blocking=opt.pin_memory)
106
+
107
+ outputs = model(**model_inputs)
108
+
109
+ # loss meters
110
+ # if criterion:
111
+ # loss_dict = criterion(outputs, targets)
112
+ # weight_dict = criterion.weight_dict
113
+ # print(loss_dict)
114
+ # print(weight_dict)
115
+ # print('#######')
116
+ # {'loss_saliency': tensor(18.1374, device='cuda:0')}
117
+ # {'loss_span': 10, 'loss_giou': 1, 'loss_label': 4, 'loss_saliency': 1.0, 'loss_ms_align': 1.0,
118
+ # 'loss_distill': 1.0, 'loss_span_0': 10, 'loss_giou_0': 1, 'loss_label_0': 4, 'loss_ms_align_0': 1.0,
119
+ # 'loss_distill_0': 1.0}
120
+ # losses=0.
121
+ # print(loss_dict.keys(), weight_dict.keys())
122
+ # losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
123
+ # loss_dict["loss_overall"] = float(losses) # for logging only
124
+ # print(loss_dict.items())
125
+ #
126
+ # print(weight_dict.items())
127
+ # for k, v in loss_dict.items():
128
+ # loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))
129
+
130
+
131
+ preds = outputs['saliency_scores'].clone().detach()
132
+
133
+ for meta, pred in zip(query_meta, preds):
134
+ pred = pred
135
+ label = meta['label'] # raw label
136
+
137
+ video_ap = []
138
+ # Follow the UMT code "https://github.com/TencentARC/UMT/blob/main/datasets/tvsum.py"
139
+
140
+ if opt.dset_name in ["tvsum"]:
141
+ for i in range(20):
142
+ pred=pred.cpu()
143
+ cur_pred = pred[:len(label)]
144
+ inds = torch.argsort(cur_pred, descending=True, dim=-1)
145
+
146
+ # video_id = self.get_video_id(idx)
147
+ cur_label = torch.Tensor(label)[:, i]
148
+ cur_label = torch.where(cur_label > cur_label.median(), 1.0, .0)
149
+
150
+ cur_label = cur_label[inds].tolist()[:topk]
151
+
152
+ # if (num_gt := sum(cur_label)) == 0:
153
+ num_gt = sum(cur_label)
154
+ if num_gt == 0:
155
+ video_ap.append(0)
156
+ continue
157
+
158
+ hits = ap = rec = 0
159
+ prc = 1
160
+
161
+ for j, gt in enumerate(cur_label):
162
+ hits += gt
163
+
164
+ _rec = hits / num_gt
165
+ _prc = hits / (j + 1)
166
+
167
+ ap += (_rec - rec) * (prc + _prc) / 2
168
+ rec, prc = _rec, _prc
169
+
170
+ video_ap.append(ap)
171
+
172
+ elif opt.dset_name in ["youtube_uni"]:
173
+ cur_pred = pred[:len(label)]
174
+ # if opt.dset_name == "tvsum_sfc":
175
+ cur_pred = cur_pred.cpu()
176
+ inds = torch.argsort(cur_pred, descending=True, dim=-1)
177
+
178
+
179
+ cur_label = torch.Tensor(label).squeeze()[inds].tolist()
180
+
181
+ num_gt = sum(cur_label)
182
+ if num_gt == 0:
183
+ video_ap.append(0)
184
+ continue
185
+
186
+ hits = ap = rec = 0
187
+ prc = 1
188
+
189
+ for j, gt in enumerate(cur_label):
190
+ hits += gt
191
+
192
+ _rec = hits / num_gt
193
+ _prc = hits / (j + 1)
194
+
195
+ ap += (_rec - rec) * (prc + _prc) / 2
196
+ rec, prc = _rec, _prc
197
+
198
+ video_ap.append(float(ap))
199
+ else:
200
+ print("No such dataset")
201
+ exit(-1)
202
+
203
+ video_ap_collected.append(video_ap)
204
+
205
+ mean_ap = np.mean(video_ap_collected)
206
+ submission = dict(mAP=round(mean_ap, 5))
207
+
208
+
209
+ # tensorboard writer
210
+ if write_tb and criterion:
211
+ for k, v in loss_meters.items():
212
+ tb_writer.add_scalar("Eval/{}".format(k), v.avg, epoch_i + 1)
213
+
214
+ return submission, loss_meters
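A toy walk-through (made-up labels) of the top-k average-precision loop used in compute_hl_results above: predictions are ranked by score, labels are binarized, and AP accumulates the trapezoidal area under the precision-recall curve.

# Toy illustration of the AP loop above; cur_label is the binarized relevance of the top-5 ranked clips.
cur_label = [1.0, 0.0, 1.0, 0.0, 1.0]
num_gt = sum(cur_label)                    # 3 relevant clips in the top-5

hits = ap = rec = 0
prc = 1
for j, gt in enumerate(cur_label):
    hits += gt
    _rec = hits / num_gt                   # recall after rank j+1
    _prc = hits / (j + 1)                  # precision after rank j+1
    ap += (_rec - rec) * (prc + _prc) / 2  # trapezoidal area under the PR curve
    rec, prc = _rec, _prc
print(round(ap, 4))                        # 0.7111 for this toy ranking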
215
+
216
+
217
+
218
+ @torch.no_grad()
219
+ def compute_mr_results(model, eval_loader, opt, epoch_i=None, criterion=None, tb_writer=None):
220
+ model.eval()
221
+ if criterion:
222
+ assert eval_loader.dataset.load_labels
223
+ criterion.eval()
224
+
225
+ loss_meters = defaultdict(AverageMeter)
226
+ write_tb = tb_writer is not None and epoch_i is not None
227
+
228
+ mr_res = []
229
+ for batch in tqdm(eval_loader, desc="compute st ed scores"):
230
+ query_meta = batch[0]
231
+
232
+ model_inputs, targets = prepare_batch_inputs(batch[1], opt.device, non_blocking=opt.pin_memory)
233
+
234
+ outputs = model(**model_inputs)
235
+ prob = F.softmax(outputs["pred_logits"], -1) # (batch_size, #queries, #classes=2)
236
+ if opt.span_loss_type == "l1":
237
+ scores = prob[..., 0] # * (batch_size, #queries) foreground label is 0, we directly take it
238
+ pred_spans = outputs["pred_spans"] # (bsz, #queries, 2)
239
+ _saliency_scores = outputs["saliency_scores"].half() # (bsz, L)
240
+ saliency_scores = []
241
+ valid_vid_lengths = model_inputs["src_vid_mask"].sum(1).cpu().tolist()
242
+ for j in range(len(valid_vid_lengths)):
243
+ saliency_scores.append(_saliency_scores[j, :int(valid_vid_lengths[j])].tolist())
244
+ else:
245
+ bsz, n_queries = outputs["pred_spans"].shape[:2] # # (bsz, #queries, max_v_l *2)
246
+ pred_spans_logits = outputs["pred_spans"].view(bsz, n_queries, 2, opt.max_v_l)
247
+ pred_span_scores, pred_spans = F.softmax(pred_spans_logits, dim=-1).max(-1) # 2 * (bsz, #queries, 2)
248
+ scores = torch.prod(pred_span_scores, 2) # (bsz, #queries)
249
+ pred_spans[:, 1] += 1
250
+ pred_spans *= opt.clip_length
251
+
252
+ # compose predictions
253
+ for idx, (meta, spans, score) in enumerate(zip(query_meta, pred_spans.cpu(), scores.cpu())):
254
+ if opt.span_loss_type == "l1":
255
+ spans = span_cxw_to_xx(spans) * meta["duration"]
256
+ spans = torch.clamp(spans, 0, meta["duration"])
257
+ # # (#queries, 3), [st(float), ed(float), score(float)]
258
+ cur_ranked_preds = torch.cat([spans, score[:, None]], dim=1).tolist()
259
+ if not opt.no_sort_results:
260
+ cur_ranked_preds = sorted(cur_ranked_preds, key=lambda x: x[2], reverse=True)
261
+ cur_ranked_preds = [[float(f"{e:.4f}") for e in row] for row in cur_ranked_preds]
262
+ cur_query_pred = dict(
263
+ qid=meta["qid"],
264
+ query=meta["query"],
265
+ vid=meta["vid"],
266
+ pred_relevant_windows=cur_ranked_preds,
267
+ pred_saliency_scores=saliency_scores[idx]
268
+ )
269
+ mr_res.append(cur_query_pred)
270
+
271
+ if criterion:
272
+ loss_dict = criterion(outputs, targets)
273
+ weight_dict = criterion.weight_dict
274
+ losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
275
+ loss_dict["loss_overall"] = float(losses) # for logging only
276
+ for k, v in loss_dict.items():
277
+ loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))
278
+
279
+ if opt.debug:
280
+ break
281
+
282
+ if write_tb and criterion:
283
+ for k, v in loss_meters.items():
284
+ tb_writer.add_scalar("Eval/{}".format(k), v.avg, epoch_i + 1)
285
+
286
+ if opt.dset_name in ['hl']:
287
+ post_processor = PostProcessorDETR(
288
+ clip_length=opt.clip_length, min_ts_val=0, max_ts_val=150,
289
+ min_w_l=2, max_w_l=150, move_window_method="left",
290
+ process_func_names=("clip_ts", "round_multiple")
291
+ )
292
+ elif opt.dset_name in ['charadesSTA']:
293
+ if opt.v_feat_dim == 4096: # vgg
294
+ post_processor = PostProcessorDETR(
295
+ clip_length=opt.clip_length, min_ts_val=0, max_ts_val=360,
296
+ min_w_l=12, max_w_l=360, move_window_method="left",
297
+ process_func_names=("clip_ts", "round_multiple")
298
+ )
299
+ else:
300
+ post_processor = PostProcessorDETR(
301
+ clip_length=opt.clip_length, min_ts_val=0, max_ts_val=150,
302
+ min_w_l=2, max_w_l=60, move_window_method="left",
303
+ process_func_names=("clip_ts", "round_multiple")
304
+ )
305
+ else:
306
+ post_processor = PostProcessorDETR(
307
+ clip_length=opt.clip_length, min_ts_val=0, max_ts_val=50000,
308
+ min_w_l=0, max_w_l=50000, move_window_method="left",
309
+ process_func_names=(["round_multiple"])
310
+ )
311
+
312
+ mr_res = post_processor(mr_res)
313
+ return mr_res, loss_meters
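A small sketch of the span post-processing in compute_mr_results above for the l1 span type; span_cxw_to_xx here is a local stand-in with the behaviour assumed from cg_detr/span_utils.py (normalized (center, width) to (start, end)).

import torch

def span_cxw_to_xx(cxw):  # assumed behaviour of cg_detr.span_utils.span_cxw_to_xx
    return torch.stack([cxw[..., 0] - 0.5 * cxw[..., 1], cxw[..., 0] + 0.5 * cxw[..., 1]], dim=-1)

duration = 120.0                                 # video length in seconds
pred_spans = torch.tensor([[0.25, 0.10],         # normalized (center, width) per query
                           [0.80, 0.50]])
scores = torch.tensor([0.9, 0.4])

windows = torch.clamp(span_cxw_to_xx(pred_spans) * duration, 0, duration)
ranked = torch.cat([windows, scores[:, None]], dim=1)  # rows of [st, ed, score], as written to the jsonl
print(ranked)  # [[24., 36., 0.9], [66., 120., 0.4]]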
314
+
315
+
316
+ def get_eval_res(model, eval_loader, opt, epoch_i, criterion, tb_writer):
317
+ """compute and save query and video proposal embeddings"""
318
+ eval_res, eval_loss_meters = compute_mr_results(model, eval_loader, opt, epoch_i, criterion, tb_writer) # list(dict)
319
+ return eval_res, eval_loss_meters
320
+
321
+
322
+ def eval_epoch(model, eval_dataset, opt, save_submission_filename, epoch_i=None, criterion=None, tb_writer=None):
323
+ logger.info("Generate submissions")
324
+ model.eval()
325
+ if criterion is not None and eval_dataset.load_labels:
326
+ criterion.eval()
327
+ else:
328
+ criterion = None
329
+
330
+ if opt.dset_name == 'tacos':
331
+ shuffle = True
332
+ else:
333
+ shuffle = False
334
+
335
+ eval_loader = DataLoader(
336
+ eval_dataset,
337
+ collate_fn=start_end_collate,
338
+ batch_size=opt.eval_bsz,
339
+ num_workers=opt.num_workers,
340
+ shuffle=shuffle,
341
+ pin_memory=opt.pin_memory
342
+ )
343
+
344
+
345
+ # tvsum
346
+ if opt.dset_name in ['tvsum', 'youtube_uni']:
347
+ metrics, eval_loss_meters = compute_hl_results(model, eval_loader, opt, epoch_i, criterion, tb_writer)
348
+
349
+ # to match original save format
350
+ submission = [
351
+ {"brief": metrics}
352
+ ]
353
+ submission_path = os.path.join(opt.results_dir, "latest_metric.jsonl")
354
+ save_jsonl(submission, submission_path)
355
+
356
+ return submission[0], submission[0], eval_loss_meters, [submission_path]
357
+
358
+ else:
359
+ submission, eval_loss_meters = get_eval_res(model, eval_loader, opt, epoch_i, criterion, tb_writer)
360
+
361
+ if opt.dset_name in ['charadesSTA', 'tacos', 'nlq']:
362
+ new_submission = []
363
+ for s in submission:
364
+ s.pop('pred_saliency_scores', None)
365
+ new_submission.append(s)
366
+ submission = new_submission
367
+
368
+ if opt.no_sort_results:
369
+ save_submission_filename = save_submission_filename.replace(".jsonl", "_unsorted.jsonl")
370
+ metrics, metrics_nms, latest_file_paths = eval_epoch_post_processing(
371
+ submission, opt, eval_dataset.data, save_submission_filename)
372
+ return metrics, metrics_nms, eval_loss_meters, latest_file_paths
373
+
374
+
375
+ def setup_model(opt):
376
+ """setup model/optimizer/scheduler and load checkpoints when needed"""
377
+ logger.info("setup model/optimizer/scheduler")
378
+ model, criterion = build_model(opt)
379
+ if opt.device.type == "cuda":
380
+ logger.info("CUDA enabled.")
381
+ model.to(opt.device)
382
+ criterion.to(opt.device)
383
+
384
+ param_dicts = [{"params": [p for n, p in model.named_parameters() if p.requires_grad]}]
385
+ optimizer = torch.optim.AdamW(param_dicts, lr=opt.lr, weight_decay=opt.wd)
386
+ lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, opt.lr_drop)
387
+
388
+ if opt.resume is not None:
389
+ logger.info(f"Load checkpoint from {opt.resume}")
390
+ checkpoint = torch.load(opt.resume, map_location="cpu")
391
+ from collections import OrderedDict
392
+ new_state_dict = OrderedDict()
393
+ if 'pt' in opt.resume[:-4]:
394
+ if 'asr' in opt.resume[:25]:
395
+ model.load_state_dict(checkpoint["model"])
396
+ else:
397
+ for k, v in checkpoint["model"].items():
398
+ name = k[7:] # remove `module.`
399
+ new_state_dict[name] = v
400
+ # model.load_state_dict(checkpoint["model"])
401
+ model.load_state_dict(new_state_dict)
402
+ else:
403
+ model.load_state_dict(checkpoint["model"])
404
+ if opt.resume_all:
405
+ optimizer.load_state_dict(checkpoint['optimizer'])
406
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
407
+ opt.start_epoch = checkpoint['epoch'] + 1
408
+ logger.info(f"Loaded model saved at epoch {checkpoint['epoch']} from checkpoint: {opt.resume}")
409
+ else:
410
+ logger.warning("If you intend to evaluate the model, please specify --resume with ckpt path")
411
+
412
+ return model, criterion, optimizer, lr_scheduler
413
+
414
+
415
+ def start_inference(train_opt=None, split=None, splitfile=None):
416
+ if train_opt is not None:
417
+ opt = TestOptions().parse(train_opt.a_feat_dir)
418
+ else:
419
+ opt = TestOptions().parse()
420
+ if split is not None:
421
+ opt.eval_split_name = split
422
+ if splitfile is not None:
423
+ opt.eval_path = splitfile
424
+
425
+ print(opt.eval_split_name)
426
+ print(opt.eval_path)
427
+ logger.info("Setup config, data and model...")
428
+
429
+
430
+ cudnn.benchmark = True
431
+ cudnn.deterministic = False
432
+
433
+ assert opt.eval_path is not None
434
+ if opt.eval_split_name == 'val':
435
+ loadlabel = True
436
+ else:
437
+ loadlabel = False
438
+
439
+ eval_dataset = StartEndDataset(
440
+ dset_name=opt.dset_name,
441
+ data_path=opt.eval_path,
442
+ v_feat_dirs=opt.v_feat_dirs,
443
+ q_feat_dir=opt.t_feat_dir,
444
+ q_feat_type="last_hidden_state",
445
+ max_q_l=opt.max_q_l,
446
+ max_v_l=opt.max_v_l,
447
+ ctx_mode=opt.ctx_mode,
448
+ data_ratio=opt.data_ratio,
449
+ normalize_v=not opt.no_norm_vfeat,
450
+ normalize_t=not opt.no_norm_tfeat,
451
+ clip_len=opt.clip_length,
452
+ max_windows=opt.max_windows,
453
+ load_labels=loadlabel, # opt.eval_split_name == "val",
454
+ span_loss_type=opt.span_loss_type,
455
+ txt_drop_ratio=0,
456
+ dset_domain=opt.dset_domain,
457
+ )
458
+
459
+
460
+
461
+ model, criterion, _, _ = setup_model(opt)
462
+
463
+ save_submission_filename = "hl_{}_submission.jsonl".format(
464
+ opt.eval_split_name)
465
+ # save_submission_filename = "inference_{}_{}_{}_preds.jsonl".format(
466
+ # opt.dset_name, opt.eval_split_name, opt.eval_id)
467
+ logger.info("Starting inference...")
468
+ with torch.no_grad():
469
+ metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \
470
+ eval_epoch(model, eval_dataset, opt, save_submission_filename, criterion=criterion)
471
+ if opt.eval_split_name == 'val':
472
+ logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4)))
473
+ if metrics_nms is not None:
474
+ logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4)))
475
+
476
+ from sys import argv
477
+ if __name__ == '__main__':
478
+ _, _, _, _, split, _, splitfile = argv
479
+
480
+ start_inference(split=split, splitfile=splitfile)
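The checkpoint loading in setup_model above strips the "module." prefix left by DataParallel wrapping; a minimal standalone sketch of that step, with toy keys:

from collections import OrderedDict

ckpt_model = {"module.span_embed.layers.0.weight": 0, "module.class_embed.bias": 1}  # toy state dict

new_state_dict = OrderedDict()
for k, v in ckpt_model.items():
    new_state_dict[k[7:] if k.startswith("module.") else k] = v  # len("module.") == 7, as in the code above

print(list(new_state_dict))  # ['span_embed.layers.0.weight', 'class_embed.bias']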
third_party/cgdetr/cg_detr/matcher.py ADDED
@@ -0,0 +1,109 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Modules to compute the matching cost and solve the corresponding LSAP.
4
+ """
5
+ import torch
6
+ from scipy.optimize import linear_sum_assignment
7
+ from torch import nn
8
+ import torch.nn.functional as F
9
+ from third_party.cgdetr.cg_detr.span_utils import generalized_temporal_iou, span_cxw_to_xx
10
+
11
+
12
+ class HungarianMatcher(nn.Module):
13
+ """This class computes an assignment between the targets and the predictions of the network
14
+
15
+ For efficiency reasons, the targets don't include the no_object. Because of this, in general,
16
+ there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
17
+ while the others are un-matched (and thus treated as non-objects).
18
+ """
19
+ def __init__(self, cost_class: float = 1, cost_span: float = 1, cost_giou: float = 1,
20
+ span_loss_type: str = "l1", max_v_l: int = 75):
21
+ """Creates the matcher
22
+
23
+ Params:
24
+ cost_span: This is the relative weight of the L1 error of the span coordinates in the matching cost
25
+ cost_giou: This is the relative weight of the giou loss of the spans in the matching cost
26
+ """
27
+ super().__init__()
28
+ self.cost_class = cost_class
29
+ self.cost_span = cost_span
30
+ self.cost_giou = cost_giou
31
+ self.span_loss_type = span_loss_type
32
+ self.max_v_l = max_v_l
33
+ self.foreground_label = 0
34
+ assert cost_class != 0 or cost_span != 0 or cost_giou != 0, "all costs can't be 0"
35
+
36
+ @torch.no_grad()
37
+ def forward(self, outputs, targets):
38
+ """ Performs the matching
39
+
40
+ Params:
41
+ outputs: This is a dict that contains at least these entries:
42
+ "pred_spans": Tensor of dim [batch_size, num_queries, 2] with the predicted span coordinates,
43
+ in normalized (cx, w) format
44
+ "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
45
+
46
+ targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
47
+ "spans": Tensor of dim [num_target_spans, 2] containing the target span coordinates. The spans are
48
+ in normalized (cx, w) format
49
+
50
+ Returns:
51
+ A list of size batch_size, containing tuples of (index_i, index_j) where:
52
+ - index_i is the indices of the selected predictions (in order)
53
+ - index_j is the indices of the corresponding selected targets (in order)
54
+ For each batch element, it holds:
55
+ len(index_i) = len(index_j) = min(num_queries, num_target_spans)
56
+ """
57
+ bs, num_queries = outputs["pred_spans"].shape[:2]
58
+ targets = targets["span_labels"]
59
+ # import pdb; pdb.set_trace()
60
+
61
+ # Also concat the target labels and spans
62
+ out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes]
63
+ tgt_spans = torch.cat([v["spans"] for v in targets]) # [num_target_spans in batch, 2]
64
+ tgt_ids = torch.full([len(tgt_spans)], self.foreground_label) # [total #spans in the batch]
65
+
66
+ # Compute the classification cost. Contrary to the loss, we don't use the NLL,
67
+ # but approximate it in 1 - prob[target class].
68
+ # The 1 is a constant that doesn't change the matching, it can be omitted.
69
+ cost_class = -out_prob[:, tgt_ids] # [batch_size * num_queries, total #spans in the batch]
70
+
71
+ if self.span_loss_type == "l1":
72
+ # We flatten to compute the cost matrices in a batch
73
+ out_spans = outputs["pred_spans"].flatten(0, 1) # [batch_size * num_queries, 2]
74
+
75
+ # Compute the L1 cost between spans
76
+ cost_span = torch.cdist(out_spans.type(torch.float32), tgt_spans.type(torch.float32), p=1) # [batch_size * num_queries, total #spans in the batch]
77
+ cost_span = cost_span.type(torch.bfloat16)
78
+
79
+ # Compute the giou cost between spans
80
+ # [batch_size * num_queries, total #spans in the batch]
81
+ cost_giou = - generalized_temporal_iou(span_cxw_to_xx(out_spans), span_cxw_to_xx(tgt_spans))
82
+ else:
83
+ pred_spans = outputs["pred_spans"] # (bsz, #queries, max_v_l * 2)
84
+ pred_spans = pred_spans.view(bs * num_queries, 2, self.max_v_l).softmax(-1) # (bsz * #queries, 2, max_v_l)
85
+ cost_span = - pred_spans[:, 0][:, tgt_spans[:, 0]] - \
86
+ pred_spans[:, 1][:, tgt_spans[:, 1]] # (bsz * #queries, #spans)
87
+ # pred_spans = pred_spans.repeat(1, n_spans, 1, 1).flatten(0, 1) # (bsz * #queries * #spans, max_v_l, 2)
88
+ # tgt_spans = tgt_spans.view(1, n_spans, 2).repeat(bs * num_queries, 1, 1).flatten(0, 1) # (bsz * #queries * #spans, 2)
89
+ # cost_span = pred_spans[tgt_spans]
90
+ # cost_span = cost_span.view(bs * num_queries, n_spans)
91
+
92
+ # giou
93
+ cost_giou = 0
94
+
95
+ # Final cost matrix
96
+ # import ipdb; ipdb.set_trace()
97
+ C = self.cost_span * cost_span + self.cost_giou * cost_giou + self.cost_class * cost_class
98
+ C = C.view(bs, num_queries, -1).cpu()
99
+
100
+ sizes = [len(v["spans"]) for v in targets]
101
+ indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
102
+ return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
103
+
104
+
105
+ def build_matcher(args):
106
+ return HungarianMatcher(
107
+ cost_span=args.set_cost_span, cost_giou=args.set_cost_giou,
108
+ cost_class=args.set_cost_class, span_loss_type=args.span_loss_type, max_v_l=args.max_v_l
109
+ )
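A minimal sketch of the assignment step HungarianMatcher delegates to SciPy: for one sample, given a (num_queries x num_target_spans) cost matrix, linear_sum_assignment returns the prediction/target index pairs with minimal total cost (numbers are made up).

import torch
from scipy.optimize import linear_sum_assignment

cost = torch.tensor([[0.9, 0.2],   # 3 predicted spans vs 2 target spans; lower is a better match
                     [0.1, 0.8],
                     [0.5, 0.6]])
row_ind, col_ind = linear_sum_assignment(cost.numpy())
indices = (torch.as_tensor(row_ind, dtype=torch.int64), torch.as_tensor(col_ind, dtype=torch.int64))
print(indices)  # (tensor([0, 1]), tensor([1, 0])): query 0 -> target 1, query 1 -> target 0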
third_party/cgdetr/cg_detr/misc.py ADDED
@@ -0,0 +1,21 @@
1
+ import torch
2
+
3
+
4
+ @torch.no_grad()
5
+ def accuracy(output, target, topk=(1,)):
6
+ """Computes the precision@k for the specified values of k
7
+ output: (#items, #classes)
8
+ target: int,
9
+ """
10
+ maxk = max(topk)
11
+ num_items = output.size(0)
12
+
13
+ _, pred = output.topk(maxk, 1, True, True)
14
+ pred = pred.t()
15
+ correct = pred.eq(target)
16
+
17
+ res = []
18
+ for k in topk:
19
+ correct_k = correct[:k].view(-1).float().sum(0)
20
+ res.append(correct_k.mul_(100.0 / num_items))
21
+ return res
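A quick usage sketch for accuracy above with toy logits for 4 items over 3 classes (assumes the repo root is on PYTHONPATH so the import below resolves):

import torch
from third_party.cgdetr.cg_detr.misc import accuracy

logits = torch.tensor([[2.0, 1.0, 0.1],   # top-1 prediction: class 0
                       [0.2, 3.0, 0.5],   # class 1
                       [1.5, 0.3, 0.9],   # class 0
                       [0.1, 0.2, 2.5]])  # class 2
target = torch.tensor([0, 1, 2, 2])       # ground-truth classes

top1, top2 = accuracy(logits, target, topk=(1, 2))
print(top1.item(), top2.item())           # 75.0 100.0 (percent)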
third_party/cgdetr/cg_detr/model.py ADDED
@@ -0,0 +1,1178 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ CG-DETR model and criterion classes.
4
+ """
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+ from third_party.cgdetr.cg_detr.span_utils import generalized_temporal_iou, span_cxw_to_xx
10
+
11
+ from third_party.cgdetr.cg_detr.matcher import build_matcher
12
+ from third_party.cgdetr.cg_detr.transformer import build_transformer, TransformerEncoderLayer, TransformerEncoder
13
+ from third_party.cgdetr.cg_detr.position_encoding import build_position_encoding
14
+ from third_party.cgdetr.cg_detr.misc import accuracy
15
+ import numpy as np
16
+ import copy
17
+
18
+ def inverse_sigmoid(x, eps=1e-3):
19
+ x = x.clamp(min=0, max=1)
20
+ x1 = x.clamp(min=eps)
21
+ x2 = (1 - x).clamp(min=eps)
22
+ return torch.log(x1/x2)
23
+
24
+ def init_weights(module):
25
+ if isinstance(module, (nn.Linear, nn.Embedding)):
26
+ module.weight.data.normal_(mean=0.0, std=0.02)
27
+ elif isinstance(module, nn.LayerNorm):
28
+ module.bias.data.zero_()
29
+ module.weight.data.fill_(1.0)
30
+
31
+ if isinstance(module, nn.Linear) and module.bias is not None:
32
+ module.bias.data.zero_()
33
+
34
+ def find_nth(vid, underline, n):
35
+ max_len = len(vid)
36
+ start = vid.find(underline)
37
+ while start >= 0 and n > 1:
38
+ start = vid.find(underline, start+len(underline))
39
+ n -= 1
40
+ if start == -1:
41
+ start = max_len
42
+ return start
43
+
44
+ def element_wise_list_equal(listA, listB):
45
+ res = []
46
+ for a, b in zip(listA, listB):
47
+ if a==b:
48
+ res.append(True)
49
+ else:
50
+ res.append(False)
51
+ return res
52
+
53
+ class CGDETR(nn.Module):
54
+ """ CG DETR. """
55
+
56
+ def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
57
+ num_queries, input_dropout, aux_loss=False,
58
+ contrastive_align_loss=False, contrastive_hdim=64,
59
+ max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2, aud_dim=0, args=None):
60
+ """ Initializes the model.
61
+ Parameters:
62
+ transformer: torch module of the transformer architecture. See transformer.py
63
+ position_embed: torch module of the position_embedding, See position_encoding.py
64
+ txt_position_embed: position_embedding for text
65
+ txt_dim: int, text query input dimension
66
+ vid_dim: int, video feature input dimension
67
+ num_queries: number of object queries, ie detection slot. This is the maximal number of objects
68
+ CG-DETR can detect in a single video.
69
+ aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
70
+ contrastive_align_loss: If true, perform span - tokens contrastive learning
71
+ contrastive_hdim: dimension used for projecting the embeddings before computing contrastive loss
72
+ max_v_l: int, maximum #clips in videos
73
+ span_loss_type: str, one of [l1, ce]
74
+ l1: (center-x, width) regression.
75
+ ce: (st_idx, ed_idx) classification.
76
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
77
+ # background_thd: float, intersection over prediction <= background_thd: labeled background
78
+ """
79
+ super().__init__()
80
+ self.args=args
81
+ self.num_queries = num_queries
82
+ self.transformer = transformer
83
+ self.position_embed = position_embed
84
+ self.txt_position_embed = txt_position_embed
85
+ hidden_dim = transformer.d_model
86
+ self.span_loss_type = span_loss_type
87
+ self.max_v_l = max_v_l
88
+ span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
89
+ self.span_embed = MLP(hidden_dim, hidden_dim, span_pred_dim, 3)
90
+ self.class_embed = nn.Linear(hidden_dim, 2) # 0: background, 1: foreground
91
+ self.token_type_embeddings = nn.Embedding(2, hidden_dim)
92
+ self.token_type_embeddings.apply(init_weights)
93
+ self.use_txt_pos = use_txt_pos
94
+ self.n_input_proj = n_input_proj
95
+ self.query_embed = nn.Embedding(num_queries, 2)
96
+ relu_args = [True] * 3
97
+ relu_args[n_input_proj-1] = False
98
+ self.input_txt_proj = nn.Sequential(*[
99
+ LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
100
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
101
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
102
+ ][:n_input_proj])
103
+ self.input_vid_proj = nn.Sequential(*[
104
+ LinearLayer(vid_dim + aud_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
105
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
106
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
107
+ ][:n_input_proj])
108
+ self.contrastive_align_loss = contrastive_align_loss
109
+ if contrastive_align_loss:
110
+ self.contrastive_align_projection_query = nn.Linear(hidden_dim, contrastive_hdim)
111
+ self.contrastive_align_projection_txt = nn.Linear(hidden_dim, contrastive_hdim)
112
+ self.contrastive_align_projection_vid = nn.Linear(hidden_dim, contrastive_hdim)
113
+
114
+ self.saliency_proj1 = nn.Linear(hidden_dim, hidden_dim)
115
+ self.saliency_proj2 = nn.Linear(hidden_dim, hidden_dim)
116
+ self.aux_loss = aux_loss
117
+ self.hidden_dim = hidden_dim
118
+ self.global_rep_token = torch.nn.Parameter(torch.randn(args.total_prompts, hidden_dim))
119
+ self.global_rep_pos = torch.nn.Parameter(torch.randn(1, hidden_dim))
120
+ self.moment_rep_token = torch.nn.Parameter(torch.randn(hidden_dim))
121
+ self.moment_rep_pos = torch.nn.Parameter(torch.randn(hidden_dim))
122
+
123
+ self.dummy_rep_token = torch.nn.Parameter(torch.randn(args.num_dummies, hidden_dim))
124
+ self.dummy_rep_pos = torch.nn.Parameter(torch.randn(args.num_dummies, hidden_dim))
125
+ normalize_before = False
126
+ self.sent_rep_token = torch.nn.Parameter(torch.randn(hidden_dim))
127
+ self.sent_rep_pos = torch.nn.Parameter(torch.randn(hidden_dim))
128
+
129
+ self.txt_proj_linear = LinearLayer(txt_dim, hidden_dim, layer_norm=True)
130
+
131
+ input_txt_sa_proj = TransformerEncoderLayer(hidden_dim, 8, self.args.dim_feedforward, 0.1, "prelu", normalize_before)
132
+ txtproj_encoder_norm = nn.LayerNorm(hidden_dim) if normalize_before else None
133
+ self.txtproj_encoder = TransformerEncoder(input_txt_sa_proj, args.dummy_layers, txtproj_encoder_norm)
134
+
135
+ scls_encoder_layer = TransformerEncoderLayer(hidden_dim, 8, self.args.dim_feedforward, 0.1, "prelu", normalize_before)
136
+ scls_encoder_norm = nn.LayerNorm(hidden_dim) if normalize_before else None
137
+ self.scls_encoder = TransformerEncoder(scls_encoder_layer, args.sent_layers, scls_encoder_norm)
138
+
139
+ def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, vid=None, qid=None, src_aud=None, src_aud_mask=None, targets=None, prompt_token=None):
140
+ """The forward expects two tensors:
141
+ - src_txt: [batch_size, L_txt, D_txt]
142
+ - src_txt_mask: [batch_size, L_txt], containing 0 on padded pixels,
143
+ will convert to 1 as padding later for transformer
144
+ - src_vid: [batch_size, L_vid, D_vid]
145
+ - src_vid_mask: [batch_size, L_vid], containing 0 on padded pixels,
146
+ will convert to 1 as padding later for transformer
147
+
148
+ It returns a dict with the following elements:
149
+ - "pred_spans": The normalized boxes coordinates for all queries, represented as
150
+ (center_x, width). These values are normalized in [0, 1],
151
+ relative to the size of each individual image (disregarding possible padding).
152
+ See PostProcess for information on how to retrieve the unnormalized bounding box.
153
+ - "aux_outputs": Optional, only returned when auxiliary losses are activated. It is a list of
155
+ dictionaries containing the two above keys for each decoder layer.
155
+ """
156
+
157
+
158
+ ## For discovering real negative samples
159
+ device = src_txt_mask.device
160
+ # import pdb; pdb.set_trace()
161
+ # if vid is not None: ## for demo (run_on_video/run.py)
162
+ # _count = [v.count('_') for v in vid]
163
+ # if self.args.dset_name == 'hl':
164
+ # _position_to_cut = [find_nth(v, '_', _count[i]-1) for i, v in enumerate(vid)]
165
+ # ori_vid = [v[:_position_to_cut[i]] for i, v in enumerate(vid)]
166
+ # else:
167
+ if vid is not None:
168
+ ori_vid = [v for v in vid]
169
+
170
+ if src_aud is not None:
171
+ src_vid = torch.cat([src_vid, src_aud], dim=2)
172
+
173
+ # --------------------------------
174
+ src_txt_list = []
175
+ src_txt_mask_list = []
176
+ for bs in range(src_txt.shape[0]):
177
+ idx = int(src_txt_mask[bs].sum().item())
178
+ src_txt_list.append(torch.cat((src_txt[bs, :idx, :], prompt_token[bs], src_txt[bs, idx:, :]), dim=0))
179
+ src_txt_mask_list.append(torch.cat((src_txt_mask[bs, :idx], torch.ones(1, dtype=torch.bfloat16).to(device), src_txt_mask[bs, idx:]), dim=0))
180
+
181
+ src_txt = torch.stack(src_txt_list, dim=0)
182
+ src_txt_mask = torch.stack(src_txt_mask_list, dim=0)
183
+ # --------------------------------
184
+
185
+ # src_txt = torch.cat((src_txt, prompt_token), dim=1)
186
+ # src_txt_mask = torch.cat((src_txt_mask, torch.zeros_like(prompt_token)), dim=1)
187
+
188
+ src_vid = self.input_vid_proj(src_vid) # [bsz,vlen,770] -> [bsz,vlen,256]
189
+ src_txt = self.input_txt_proj(src_txt) # [bsz,qlen,4096] -> [bsz,qlen, 256]
190
+
191
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1)) # TODO
192
+ src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
193
+
194
+ #
195
+ pos_vid = self.position_embed(src_vid, src_vid_mask).type(torch.bfloat16) # (bsz, L_vid, d)
196
+ pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt).type(torch.bfloat16) # (bsz, L_txt, d)
197
+
198
+ ### insert dummy token in front of txt
199
+ txt_dummy = self.dummy_rep_token.reshape([1, self.args.num_dummies, self.hidden_dim]).repeat(src_txt.shape[0], 1, 1) # [bsz, 45, 256]
200
+ src_txt_dummy = torch.cat([txt_dummy, src_txt], dim=1) # [bsz, L_txt+45, 256]
201
+ mask_txt = torch.tensor([[True] * self.args.num_dummies]).to(src_txt_mask.device).repeat(src_txt_mask.shape[0], 1)
202
+ src_txt_mask_dummy = torch.cat([mask_txt, src_txt_mask], dim=1) # [bsz, L_txt+45]
203
+
204
+ pos_dummy = self.dummy_rep_pos.reshape([1, self.args.num_dummies, self.hidden_dim]).repeat(pos_txt.shape[0], 1, 1).type(torch.bfloat16)
205
+ pos_txt_dummy = torch.cat([pos_dummy, pos_txt], dim=1)
206
+ src_txt_dummy = src_txt_dummy.permute(1, 0, 2) # (L, batch_size, d)
207
+ pos_txt_dummy = pos_txt_dummy.permute(1, 0, 2) # (L, batch_size, d)
208
+
209
+ memory = self.txtproj_encoder(src_txt_dummy, src_key_padding_mask=~(src_txt_mask_dummy.bool()), pos=pos_txt_dummy) # (L, batch_size, d)
210
+ dummy_token = memory[:self.args.num_dummies].permute(1, 0, 2)
211
+ pos_txt_dummy = pos_txt_dummy.permute(1, 0, 2) # (L, batch_size, d)
212
+
213
+ src_txt_dummy = torch.cat([dummy_token, src_txt], dim=1)
214
+ mask_txt_dummy = torch.tensor([[True]*self.args.num_dummies]).to(src_txt_mask.device).repeat(src_txt_mask.shape[0], 1)
215
+ src_txt_mask_dummy = torch.cat([mask_txt_dummy, src_txt_mask], dim=1)
216
+
217
+ # Input : Concat video, dummy, txt
218
+ src = torch.cat([src_vid, src_txt_dummy], dim=1) # (bsz, L_vid+L_txt, d)
219
+ mask = torch.cat([src_vid_mask, src_txt_mask_dummy], dim=1).bool() # (bsz, L_vid+L_txt)
220
+ pos = torch.cat([pos_vid, pos_txt_dummy], dim=1)
221
+
222
+ ### sentence token
223
+ smask_ = torch.tensor([[True]]).to(mask.device).repeat(src_txt_mask.shape[0], 1)
224
+ smask = torch.cat([smask_, src_txt_mask.bool()], dim=1)
225
+ ssrc_ = self.sent_rep_token.reshape([1, 1, self.hidden_dim]).repeat(src_txt.shape[0], 1, 1)
226
+ ssrc = torch.cat([ssrc_, src_txt], dim=1)
227
+ spos_ = self.sent_rep_pos.reshape([1, 1, self.hidden_dim]).repeat(pos_txt.shape[0], 1, 1)
228
+ spos = torch.cat([spos_, pos_txt], dim=1)
229
+ ### dummy sentence token
230
+ smaskd = torch.cat([smask_, mask_txt_dummy.bool()], dim=1)
231
+ ssrcd = torch.cat([ssrc_, dummy_token], dim=1)
232
+ sposd = torch.cat([spos_, pos_dummy], dim=1)
233
+
234
+ if targets is not None: # train
235
+ mmask_ = torch.tensor([[True]]).to(mask.device).repeat(src_vid_mask.shape[0], 1)
236
+ mmask = torch.cat([mmask_, src_vid_mask.bool()], dim=1) # [bsz, L_vid+1]
237
+ moment_mask_ = torch.clamp(targets["relevant_clips"], 0, 1).bool()
238
+ moment_mask = torch.cat([mmask_, moment_mask_], dim=1) # [bsz, L_vid+1]
239
+ # if moment_mask.shape[1] != 76:
240
+ # import pdb; pdb.set_trace()
241
+ mmask = mmask * moment_mask
242
+
243
+ msrc_ = self.moment_rep_token.reshape([1, 1, self.hidden_dim]).repeat(src_vid.shape[0], 1, 1)
244
+ msrc = torch.cat([msrc_, src_vid], dim=1)
245
+ mpos_ = self.moment_rep_pos.reshape([1, 1, self.hidden_dim]).repeat(pos_vid.shape[0], 1, 1)
246
+ mpos = torch.cat([mpos_, pos_vid], dim=1)
247
+
248
+ ### for Not moment token ####
249
+ nmmask_ = torch.tensor([[True]]).to(mask.device).repeat(src_vid_mask.shape[0], 1)
250
+ nmmask = torch.cat([nmmask_, src_vid_mask.bool()], dim=1)
251
+ nmoment_mask_ = ~(torch.clamp(targets["relevant_clips"], 0, 1).bool())
252
+ nmoment_mask = torch.cat([nmmask_, nmoment_mask_], dim=1)
253
+ nmmask = nmmask * nmoment_mask
254
+
255
+ nmsrc_ = self.moment_rep_token.reshape([1, 1, self.hidden_dim]).repeat(src_vid.shape[0], 1, 1)
256
+ nmsrc = torch.cat([nmsrc_, src_vid], dim=1)
257
+ nmpos_ = self.moment_rep_pos.reshape([1, 1, self.hidden_dim]).repeat(pos_vid.shape[0], 1, 1)
258
+ nmpos = torch.cat([nmpos_, pos_vid], dim=1)
259
+ ###########
260
+ else:
261
+ moment_mask_ = None
262
+
263
+ # for t2vidavg sal token
264
+ # import pdb; pdb.set_trace()
265
+ vidsrc_ = torch.zeros((len(src_vid), 1, self.hidden_dim), dtype=torch.bfloat16).to(device)
266
+ for i in range(len(src_vid)):
267
+ vidsrc_[i] = src_vid[i][:src_vid_mask.sum(1)[i].long()].mean(0).clone().detach()
268
+
269
+ video_length = src_vid.shape[1]
270
+ if targets is not None: ## train
271
+ ssrc = ssrc.permute(1, 0, 2) # (L, batch_size, d)
272
+ spos = spos.permute(1, 0, 2) # (L, batch_size, d)
273
+ smemory = self.scls_encoder(ssrc, src_key_padding_mask=~smask, pos=spos) # (L, batch_size, d)
274
+ sentence_txt, smemory_words = smemory[0], smemory[1:] # sentence_txt : (batch_size, d)
275
+
276
+ ssrcd = ssrcd.permute(1, 0, 2) # (L, batch_size, d)
277
+ sposd = sposd.permute(1, 0, 2) # (L, batch_size, d)
278
+ smemoryd = self.scls_encoder(ssrcd, src_key_padding_mask=~smaskd, pos=sposd) # (L, batch_size, d)
279
+ sentence_dummy, smemory_words_dummy = smemoryd[0], smemoryd[1:]
280
+
281
+ txt_dummy_proj = torch.cat([smemory_words_dummy, smemory_words], dim=0)
282
+
283
+ # import pdb; pdb.set_trace()
284
+ # print(src.dtype)
285
+ hs, reference, memory, memory_global, attn_weights, memory_moment, nmmemory_moment, mmemory_frames, nmmemory_frames = self.transformer(src, ~mask, self.query_embed.weight, pos, video_length=video_length, moment_idx=targets["relevant_clips"], msrc=msrc, mpos=mpos, mmask=~mmask, nmsrc=nmsrc, nmpos=nmpos, nmmask=~nmmask,
286
+ ctxtoken=vidsrc_, gtoken=self.global_rep_token, gpos=self.global_rep_pos, vlen=src_vid_mask.sum(1).long())
287
+ moment2txt_similarity = torch.matmul(mmemory_frames.permute(1, 0, 2), txt_dummy_proj.permute(1, 2, 0))
288
+ nmoment2txt_similarity = torch.matmul(nmmemory_frames.permute(1, 0, 2), txt_dummy_proj.permute(1, 2, 0))
289
+ else: ## inference
290
+ sentence_dummy, sentence_txt, moment2txt_similarity, nmoment2txt_similarity = None, None, None, None
291
+ hs, reference, memory, memory_global, attn_weights, memory_moment, nmmemory_moment, mmemory_frames, nmmemory_frames = self.transformer(src, ~mask, self.query_embed.weight, pos, video_length=video_length,
292
+ ctxtoken=vidsrc_, gtoken=self.global_rep_token, gpos=self.global_rep_pos, vlen=src_vid_mask.sum(1).long())
293
+ outputs_class = self.class_embed(hs) # (#layers, batch_size, #queries, #classes)
294
+ reference_before_sigmoid = inverse_sigmoid(reference)
295
+ tmp = self.span_embed(hs)
296
+ outputs_coord = tmp + reference_before_sigmoid
297
+ if self.span_loss_type == "l1":
298
+ outputs_coord = outputs_coord.sigmoid()
299
+ out = {'pred_logits': outputs_class[-1], 'pred_spans': outputs_coord[-1]}
300
+
301
+ txt_mem = memory[:, src_vid.shape[1]:] # (bsz, L_txt, d)
302
+ vid_mem = memory[:, :src_vid.shape[1]] # (bsz, L_vid, d)
303
+ if self.contrastive_align_loss:
304
+ proj_queries = F.normalize(self.contrastive_align_projection_query(hs), p=2, dim=-1)
305
+ proj_txt_mem = F.normalize(self.contrastive_align_projection_txt(txt_mem), p=2, dim=-1)
306
+ proj_vid_mem = F.normalize(self.contrastive_align_projection_vid(vid_mem), p=2, dim=-1)
307
+ out.update(dict(
308
+ proj_queries=proj_queries[-1],
309
+ proj_txt_mem=proj_txt_mem,
310
+ proj_vid_mem=proj_vid_mem
311
+ ))
312
+
313
+ if vid is not None: ## for demo (run_on_video/run.py)
314
+ ### Neg Pairs ###
315
+ neg_vid = ori_vid[1:] + ori_vid[:1]
316
+
317
+ real_neg_mask = torch.Tensor(element_wise_list_equal(ori_vid, neg_vid)).to(src_txt_dummy.device)
318
+ real_neg_mask = real_neg_mask.type(torch.bfloat16)
319
+
320
+ real_neg_mask = real_neg_mask == False
321
+
322
+ # import pdb; pdb.set_trace()
323
+ if real_neg_mask.sum() != 0:
324
+
325
+ src_txt_dummy_neg = torch.cat([src_txt_dummy[1:], src_txt_dummy[0:1]], dim=0)
326
+ src_txt_mask_dummy_neg = torch.cat([src_txt_mask_dummy[1:], src_txt_mask_dummy[0:1]], dim=0)
327
+ src_dummy_neg = torch.cat([src_vid, src_txt_dummy_neg], dim=1)
328
+ mask_dummy_neg = torch.cat([src_vid_mask, src_txt_mask_dummy_neg], dim=1).bool()
329
+ pos_neg = pos.clone() # since it does not use actual content
330
+
331
+ mask_dummy_neg = mask_dummy_neg[real_neg_mask]
332
+ src_dummy_neg = src_dummy_neg[real_neg_mask]
333
+ pos_neg = pos_neg[real_neg_mask]
334
+ src_txt_mask_dummy_neg = src_txt_mask_dummy_neg[real_neg_mask]
335
+
336
+ # import pdb; pdb.set_trace()
337
+ _, _, memory_neg, memory_global_neg, attn_weights_neg, _, _, _, _ = self.transformer(src_dummy_neg, ~mask_dummy_neg, self.query_embed.weight, pos_neg, video_length=video_length,
338
+ ctxtoken=vidsrc_[real_neg_mask], gtoken=self.global_rep_token, gpos=self.global_rep_pos, vlen=src_vid_mask[real_neg_mask].sum(1).long())
339
+ vid_mem_neg = memory_neg[:, :src_vid.shape[1]]
340
+ out["saliency_scores_neg"] = (torch.sum(self.saliency_proj1(vid_mem_neg) * self.saliency_proj2(memory_global_neg).unsqueeze(1), dim=-1) / np.sqrt(self.hidden_dim))
341
+ out["src_txt_mask_neg"] = src_txt_mask_dummy_neg
342
+
343
+ out["t2vattnvalues_neg"] = (attn_weights_neg[:, :, self.args.num_dummies:] * (src_txt_mask_dummy_neg[:, self.args.num_dummies:].unsqueeze(1).repeat(1, video_length, 1))).sum(2)
344
+ out["t2vattnvalues_neg"] = torch.clamp(out["t2vattnvalues_neg"], 0, 1)
345
+ else:
346
+ out["saliency_scores_neg"] = None
347
+ out["t2vattnvalues_neg"] = None
348
+ out["real_neg_mask"] = real_neg_mask
349
+ else:
350
+ out["saliency_scores_neg"] = None
351
+ out["t2vattnvalues_neg"] = None
352
+ out["real_neg_mask"] = None
353
+
354
+
355
+ out["saliency_scores"] = (torch.sum(self.saliency_proj1(vid_mem) * self.saliency_proj2(memory_global).unsqueeze(1), dim=-1) / np.sqrt(self.hidden_dim))
356
+ out["memory_moment"] = memory_moment
357
+ out["nmmemory_moment"] = nmmemory_moment
358
+
359
+ ## sentence token embedded with text / dummy
360
+ out["sentence_txt"] = sentence_txt
361
+ out["sentence_dummy"] = sentence_dummy
362
+ out["moment2txt_similarity"] = moment2txt_similarity
363
+ out["nmoment2txt_similarity"] = nmoment2txt_similarity
364
+ out["cate_attn_weights"] = attn_weights
365
+ out["moment_mask"] = moment_mask_
366
+ out["txt_mask"] = src_txt_mask_dummy
367
+
368
+
369
+ out["t2vattnvalues"] = (attn_weights[:,:,self.args.num_dummies:] * (src_txt_mask.unsqueeze(1).repeat(1, video_length, 1))).sum(2) # (batch_size, L_vid, L_txt) / (batch_size, L_txt)
370
+ out["t2vattnvalues"] = torch.clamp(out["t2vattnvalues"], 0, 1)
371
+ out["dummy_tokens"] = dummy_token
372
+ out["global_rep_tokens"] = self.global_rep_token
373
+
374
+ # import pdb; pdb.set_trace()
375
+ if targets is not None:
376
+ out["src_vid"] = mmemory_frames.permute(1, 0, 2) * moment_mask_.unsqueeze(2) + nmmemory_frames.permute(1, 0, 2) * (~(moment_mask_.unsqueeze(2).bool())).bfloat16()
377
+ else:
378
+ out["src_vid"] = None
379
+
380
+ out["video_mask"] = src_vid_mask
381
+ if self.aux_loss:
382
+ # assert proj_queries and proj_txt_mem
383
+ out['aux_outputs'] = [
384
+ {'pred_logits': a, 'pred_spans': b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
385
+ if self.contrastive_align_loss:
386
+ assert proj_queries is not None
387
+ for idx, d in enumerate(proj_queries[:-1]):
388
+ out['aux_outputs'][idx].update(dict(proj_queries=d, proj_txt_mem=proj_txt_mem))
389
+ return out
390
+
391
+ class SetCriterion(nn.Module):
392
+ """ This class computes the loss for CG-DETR, following DETR's set-prediction formulation.
393
+ The process happens in two steps:
394
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
395
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
396
+ """
397
+
398
+ def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
399
+ saliency_margin=1, use_matcher=True, args=None):
400
+ """ Create the criterion.
401
+ Parameters:
402
+ matcher: module able to compute a matching between targets and proposals
403
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
404
+ eos_coef: relative classification weight applied to the no-object category
405
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
406
+ temperature: float, temperature for NCE loss
407
+ span_loss_type: str, [l1, ce]
408
+ max_v_l: int,
409
+ saliency_margin: float
410
+ """
411
+ super().__init__()
412
+ self.args=args
413
+ self.matcher = matcher
414
+ self.weight_dict = weight_dict
415
+ self.losses = losses
416
+ self.temperature = temperature
417
+ self.span_loss_type = span_loss_type
418
+ self.max_v_l = max_v_l
419
+ self.saliency_margin = saliency_margin
420
+
421
+ # foreground and background classification
422
+ self.foreground_label = 0
423
+ self.background_label = 1
424
+ self.eos_coef = eos_coef
425
+ empty_weight = torch.ones(2)
426
+ empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
427
+ self.register_buffer('empty_weight', empty_weight)
428
+
429
+ # for tvsum,
430
+ self.use_matcher = use_matcher
431
+
432
+ # moment sentence contrastive
433
+ self.criterion = torch.nn.CrossEntropyLoss()#.to(self.args.device)
434
+ self.l2_criterion = torch.nn.MSELoss()#.to(self.args.device)
435
+ self.kld_criterion = torch.nn.KLDivLoss(reduction='none')#.to(self.args.device)
436
+ self.bce_criterion = nn.BCELoss(reduction='none')
437
+
438
+ def loss_spans(self, outputs, targets, indices):
439
+ """Compute the losses related to the predicted temporal spans: the L1 regression loss and the GIoU loss
440
+ targets dicts must contain the key "spans" containing a tensor of dim [nb_tgt_spans, 2]
441
+ The target spans are expected in (center, width) format, normalized by the video duration.
442
+ """
443
+ assert 'pred_spans' in outputs
444
+ targets = targets["span_labels"]
445
+ idx = self._get_src_permutation_idx(indices)
446
+ src_spans = outputs['pred_spans'][idx] # (#spans, max_v_l * 2)
447
+ tgt_spans = torch.cat([t['spans'][i] for t, (_, i) in zip(targets, indices)], dim=0) # (#spans, 2)
448
+ if self.span_loss_type == "l1":
449
+ loss_span = F.l1_loss(src_spans, tgt_spans, reduction='none')
450
+ loss_giou = 1 - torch.diag(generalized_temporal_iou(span_cxw_to_xx(src_spans), span_cxw_to_xx(tgt_spans)))
451
+ else: # ce
452
+ n_spans = src_spans.shape[0]
453
+ src_spans = src_spans.view(n_spans, 2, self.max_v_l).transpose(1, 2)
454
+ loss_span = F.cross_entropy(src_spans, tgt_spans, reduction='none')
455
+ loss_giou = loss_span.new_zeros([1])
456
+
457
+ losses = {}
458
+ losses['loss_span'] = loss_span.mean()
459
+ losses['loss_giou'] = loss_giou.mean()
460
+ return losses
461
+
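For orientation, a self-contained toy version of what loss_spans computes on already-matched pairs. The small helpers below re-implement the (center, width) conversion and the 1-D generalized IoU only for illustration; they are not the repository's span_utils functions.

import torch
import torch.nn.functional as F

def cxw_to_xx(spans):
    # (center, width) -> (start, end)
    return torch.stack([spans[..., 0] - 0.5 * spans[..., 1],
                        spans[..., 0] + 0.5 * spans[..., 1]], dim=-1)

def giou_1d(a, b):
    # a, b: (N, 2) aligned spans in (start, end) format
    inter = (torch.min(a[:, 1], b[:, 1]) - torch.max(a[:, 0], b[:, 0])).clamp(min=0)
    union = (a[:, 1] - a[:, 0]) + (b[:, 1] - b[:, 0]) - inter
    enclose = torch.max(a[:, 1], b[:, 1]) - torch.min(a[:, 0], b[:, 0])
    return inter / union - (enclose - union) / enclose

pred = torch.tensor([[0.50, 0.20], [0.30, 0.40]])   # matched predictions, (center, width) in [0, 1]
tgt  = torch.tensor([[0.55, 0.20], [0.25, 0.30]])   # matched targets
loss_span = F.l1_loss(pred, tgt)                                    # mean element-wise L1
loss_giou = (1 - giou_1d(cxw_to_xx(pred), cxw_to_xx(tgt))).mean()   # 1 - generalized temporal IoU
print(loss_span.item(), loss_giou.item())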
462
+ def loss_labels(self, outputs, targets, indices, log=True):
463
+ """Classification loss (NLL)
464
+ targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
465
+ """
466
+ # TODO add foreground and background classifier. use all non-matched as background.
467
+ assert 'pred_logits' in outputs
468
+ src_logits = outputs['pred_logits'] # (batch_size, #queries, #classes=2)
469
+ # idx is a tuple of two 1D tensors (batch_idx, src_idx), of the same length == #objects in batch
470
+ idx = self._get_src_permutation_idx(indices)
471
+ target_classes = torch.full(src_logits.shape[:2], self.background_label,
472
+ dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
473
+ target_classes[idx] = self.foreground_label
474
+
475
+ loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight, reduction="none")
476
+ losses = {'loss_label': loss_ce.mean()}
477
+
478
+ if log:
479
+ # TODO this should probably be a separate loss, not hacked in this one here
480
+ losses['class_error'] = 100 - accuracy(src_logits[idx], self.foreground_label)[0]
481
+ return losses
482
+
483
+ def loss_saliency(self, outputs, targets, indices, log=True):
484
+ """higher scores for positive clips"""
485
+ if "saliency_pos_labels" not in targets:
486
+ return {"loss_saliency": 0}
487
+
488
+ # Neg pair loss
489
+ if outputs["saliency_scores_neg"] is not None: ## When batch size is not 1 (negative pair exists)
490
+ vid_token_mask = outputs["video_mask"]
491
+ real_neg_mask = outputs["real_neg_mask"]
492
+ saliency_scores_neg = outputs["saliency_scores_neg"].clone() # (N, L)
493
+ loss_neg_pair = (- torch.log(1. - torch.sigmoid(saliency_scores_neg)) * (vid_token_mask[real_neg_mask])).sum(dim=1).mean()
494
+
495
+ saliency_scores = outputs["saliency_scores"].clone() # (N, L)
496
+ saliency_contrast_label = targets["saliency_all_labels"]
497
+
498
+ # real neg
499
+ realneg_saliency_scores = torch.cat([saliency_scores[real_neg_mask], saliency_scores_neg], dim=1)
500
+ realneg_saliency_contrast_label = torch.cat([saliency_contrast_label[real_neg_mask], torch.zeros_like(saliency_contrast_label)[real_neg_mask]], dim=1)
501
+ realneg_vid_token_mask = vid_token_mask[real_neg_mask].repeat([1, 2])
502
+ realneg_saliency_scores = realneg_vid_token_mask * realneg_saliency_scores + (1. - realneg_vid_token_mask) * -1e+3
503
+
504
+ tau = 0.5
505
+ loss_rank_contrastive = 0.
506
+ for rand_idx in range(1, 12):
507
+ drop_mask = ~(realneg_saliency_contrast_label > 100) # no drop
508
+ pos_mask = (realneg_saliency_contrast_label >= rand_idx) # positive when equal or higher than rand_idx
509
+ if torch.sum(pos_mask) == 0: # no positive sample
510
+ continue
511
+ else:
512
+ batch_drop_mask = torch.sum(pos_mask, dim=1) > 0 # negative sample indicator
513
+
514
+ # drop higher ranks
515
+ cur_saliency_scores = realneg_saliency_scores * drop_mask / tau + ~drop_mask * -1e+3
516
+ # numerical stability
517
+ logits = cur_saliency_scores - torch.max(cur_saliency_scores, dim=1, keepdim=True)[0]
518
+ # softmax
519
+ exp_logits = torch.exp(logits)
520
+ log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6)
521
+
522
+ mean_log_prob_pos = (pos_mask * log_prob * realneg_vid_token_mask).sum(1) / (pos_mask.sum(1) + 1e-6)
523
+ loss = - mean_log_prob_pos * batch_drop_mask
524
+ loss_rank_contrastive = loss_rank_contrastive + loss.mean()
525
+ loss_rank_contrastive = loss_rank_contrastive / 12
526
+
527
+ false_neg_mask = ~(real_neg_mask)
528
+ if false_neg_mask.sum() != 0:
529
+ if false_neg_mask.sum() == 1:
530
+ falseneg_saliency_scores = saliency_scores[false_neg_mask].unsqueeze(0)
531
+ falseneg_saliency_contrast_label = saliency_contrast_label[false_neg_mask].unsqueeze(0)
532
+ falseneg_vid_token_mask = vid_token_mask[false_neg_mask].unsqueeze(0)
533
+ falseneg_saliency_scores = falseneg_vid_token_mask * falseneg_saliency_scores + (1. - falseneg_vid_token_mask) * -1e+3
534
+ else:
535
+ falseneg_saliency_scores = saliency_scores[false_neg_mask]
536
+ falseneg_saliency_contrast_label = saliency_contrast_label[false_neg_mask]
537
+ falseneg_vid_token_mask = vid_token_mask[false_neg_mask]
538
+ falseneg_saliency_scores = falseneg_vid_token_mask * falseneg_saliency_scores + (1. - falseneg_vid_token_mask) * -1e+3
539
+
540
+ tau = 0.5
541
+ falseneg_loss_rank_contrastive = 0.
542
+ for rand_idx in range(1, 12):
543
+ drop_mask = ~(falseneg_saliency_contrast_label > 100) # no drop
544
+ pos_mask = (falseneg_saliency_contrast_label >= rand_idx) # positive when equal or higher than rand_idx
545
+ if torch.sum(pos_mask) == 0: # no positive sample
546
+ continue
547
+ else:
548
+ batch_drop_mask = torch.sum(pos_mask, dim=1) > 0 # negative sample indicator
549
+
550
+ # drop higher ranks
551
+ cur_saliency_scores = falseneg_saliency_scores * drop_mask / tau + ~drop_mask * -1e+3
552
+ # numerical stability
553
+ logits = cur_saliency_scores - torch.max(cur_saliency_scores, dim=1, keepdim=True)[0]
554
+ # softmax
555
+ exp_logits = torch.exp(logits)
556
+ log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6)
557
+
558
+ mean_log_prob_pos = (pos_mask * log_prob * falseneg_vid_token_mask).sum(1) / (pos_mask.sum(1) + 1e-6)
559
+ loss = - mean_log_prob_pos * batch_drop_mask
560
+ falseneg_loss_rank_contrastive = falseneg_loss_rank_contrastive + loss.mean()
561
+ falseneg_loss_rank_contrastive = falseneg_loss_rank_contrastive / 12
562
+ loss_rank_contrastive += falseneg_loss_rank_contrastive
563
+
564
+ saliency_scores = outputs["saliency_scores"] # (N, L)
565
+ pos_indices = targets["saliency_pos_labels"] # (N, #pairs)
566
+ neg_indices = targets["saliency_neg_labels"] # (N, #pairs)
567
+ num_pairs = pos_indices.shape[1] # typically 2 or 4
568
+ batch_indices = torch.arange(len(saliency_scores)).to(saliency_scores.device)
569
+ pos_scores = torch.stack(
570
+ [saliency_scores[batch_indices, pos_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
571
+ neg_scores = torch.stack(
572
+ [saliency_scores[batch_indices, neg_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
573
+ loss_saliency = torch.clamp(self.saliency_margin + neg_scores - pos_scores, min=0).sum() \
574
+ / (len(pos_scores) * num_pairs) * 2 # * 2 to keep the loss the same scale
575
+
576
+ # if self.args.dset_name in ['youtube_uni']:
577
+ # loss_saliency = loss_saliency + loss_rank_contrastive + loss_neg_pair * 0.
578
+ # else:
579
+ loss_saliency = loss_saliency + loss_rank_contrastive + loss_neg_pair
580
+
581
+ ########### Saliency loss to t2v attn weights ##############
582
+ """higher scores for positive clips"""
583
+ vid_token_mask = outputs["video_mask"]
584
+ # Neg pair loss
585
+
586
+ if outputs["t2vattnvalues_neg"] is not None:
587
+ saliency_scores_neg = outputs["t2vattnvalues_neg"].clone() # (N, L)
588
+ loss_neg_pair_attn = (- torch.log(1. - saliency_scores_neg) * (vid_token_mask[real_neg_mask])).sum(dim=1).mean()
589
+
590
+ saliency_scores = outputs["t2vattnvalues"].clone() # (N, L)
591
+ saliency_contrast_label = targets["saliency_all_labels"]
592
+
593
+ # real neg
594
+ realneg_saliency_scores = torch.cat([saliency_scores[real_neg_mask], saliency_scores_neg], dim=1)
595
+ realneg_saliency_contrast_label = torch.cat(
596
+ [saliency_contrast_label[real_neg_mask], torch.zeros_like(saliency_contrast_label)[real_neg_mask]], dim=1)
597
+ realneg_vid_token_mask = vid_token_mask[real_neg_mask].repeat([1, 2])
598
+ realneg_saliency_scores = realneg_vid_token_mask * realneg_saliency_scores + (
599
+ 1. - realneg_vid_token_mask) * -1e+3
600
+
601
+ tau = 0.5
602
+ loss_rank_contrastive_attn = 0.
603
+ for rand_idx in range(1, 12):
604
+ drop_mask = ~(realneg_saliency_contrast_label > 100) # no drop
605
+ pos_mask = (realneg_saliency_contrast_label >= rand_idx) # positive when equal or higher than rand_idx
606
+ if torch.sum(pos_mask) == 0: # no positive sample
607
+ continue
608
+ else:
609
+ batch_drop_mask = torch.sum(pos_mask, dim=1) > 0 # negative sample indicator
610
+
611
+ # drop higher ranks
612
+ cur_saliency_scores = realneg_saliency_scores * drop_mask / tau + ~drop_mask * -1e+3
613
+ # numerical stability
614
+ logits = cur_saliency_scores - torch.max(cur_saliency_scores, dim=1, keepdim=True)[0]
615
+ # softmax
616
+ exp_logits = torch.exp(logits)
617
+ log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6)
618
+
619
+ mean_log_prob_pos = (pos_mask * log_prob * realneg_vid_token_mask).sum(1) / (pos_mask.sum(1) + 1e-6)
620
+ loss = - mean_log_prob_pos * batch_drop_mask
621
+ loss_rank_contrastive_attn = loss_rank_contrastive_attn + loss.mean()
622
+ loss_rank_contrastive_attn = loss_rank_contrastive_attn / 12
623
+
624
+ false_neg_mask = ~(real_neg_mask)
625
+ if false_neg_mask.sum() != 0:
626
+ if false_neg_mask.sum() == 1:
627
+ falseneg_saliency_scores = saliency_scores[false_neg_mask].unsqueeze(0)
628
+ falseneg_saliency_contrast_label = saliency_contrast_label[false_neg_mask].unsqueeze(0)
629
+ falseneg_vid_token_mask = vid_token_mask[false_neg_mask].unsqueeze(0)
630
+ falseneg_saliency_scores = falseneg_vid_token_mask * falseneg_saliency_scores + (1. - falseneg_vid_token_mask) * -1e+3
631
+ else:
632
+ falseneg_saliency_scores = saliency_scores[false_neg_mask]
633
+ falseneg_saliency_contrast_label = saliency_contrast_label[false_neg_mask]
634
+ falseneg_vid_token_mask = vid_token_mask[false_neg_mask]
635
+ falseneg_saliency_scores = falseneg_vid_token_mask * falseneg_saliency_scores + (1. - falseneg_vid_token_mask) * -1e+3
636
+
637
+ tau = 0.5
638
+ falseneg_loss_rank_contrastive = 0.
639
+ for rand_idx in range(1, 12):
640
+ drop_mask = ~(falseneg_saliency_contrast_label > 100) # no drop
641
+ pos_mask = (falseneg_saliency_contrast_label >= rand_idx) # positive when equal or higher than rand_idx
642
+ if torch.sum(pos_mask) == 0: # no positive sample
643
+ continue
644
+ else:
645
+ batch_drop_mask = torch.sum(pos_mask, dim=1) > 0 # negative sample indicator
646
+
647
+ # drop higher ranks
648
+ cur_saliency_scores = falseneg_saliency_scores * drop_mask / tau + ~drop_mask * -1e+3
649
+ # numerical stability
650
+ logits = cur_saliency_scores - torch.max(cur_saliency_scores, dim=1, keepdim=True)[0]
651
+ # softmax
652
+ exp_logits = torch.exp(logits)
653
+ log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6)
654
+
655
+ mean_log_prob_pos = (pos_mask * log_prob * falseneg_vid_token_mask).sum(1) / (pos_mask.sum(1) + 1e-6)
656
+ loss = - mean_log_prob_pos * batch_drop_mask
657
+ falseneg_loss_rank_contrastive = falseneg_loss_rank_contrastive + loss.mean()
658
+ falseneg_loss_rank_contrastive = falseneg_loss_rank_contrastive / 12
659
+ loss_rank_contrastive += falseneg_loss_rank_contrastive
660
+
661
+ saliency_scores = outputs["t2vattnvalues"] # (N, L)
662
+ pos_indices = targets["saliency_pos_labels"] # (N, #pairs)
663
+ neg_indices = targets["saliency_neg_labels"] # (N, #pairs)
664
+ num_pairs = pos_indices.shape[1] # typically 2 or 4
665
+ batch_indices = torch.arange(len(saliency_scores)).to(saliency_scores.device)
666
+ pos_scores = torch.stack(
667
+ [saliency_scores[batch_indices, pos_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
668
+ neg_scores = torch.stack(
669
+ [saliency_scores[batch_indices, neg_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
670
+ loss_saliency_attn = torch.clamp(self.saliency_margin + neg_scores - pos_scores, min=0).sum() \
671
+ / (len(pos_scores) * num_pairs) * 2 # * 2 to keep the loss the same scale
672
+
673
+ saliency_binary_label = torch.clamp(targets["saliency_all_labels"], 0, 1)
674
+ logits = saliency_scores.reshape(-1)
675
+ labels_x = saliency_binary_label.reshape(-1)
676
+ BCEcriterion = nn.BCELoss()
677
+ bceloss = BCEcriterion(logits, labels_x)
678
+
679
+ # if self.args.dset_name in ['youtube_uni']:
680
+ # loss_saliency_attn = loss_rank_contrastive_attn + bceloss + loss_neg_pair_attn * 0 + loss_saliency_attn
681
+ # else:
682
+ loss_saliency_attn = loss_rank_contrastive_attn + bceloss + loss_neg_pair_attn + loss_saliency_attn
683
+
684
+ loss_saliency += (loss_saliency_attn * self.args.lw_wattn)
685
+
686
+ else: ## when batch size == 1
687
+ vid_token_mask = outputs["video_mask"]
688
+ saliency_scores = outputs["saliency_scores"].clone() # (N, L)
689
+ saliency_contrast_label = targets["saliency_all_labels"]
690
+
691
+ saliency_scores = vid_token_mask * saliency_scores + (1. - vid_token_mask) * -1e+3
692
+
693
+ tau = 0.5
694
+ loss_rank_contrastive = 0.
695
+ for rand_idx in range(1, 12):
696
+ drop_mask = ~(saliency_contrast_label > 100) # no drop
697
+ pos_mask = (saliency_contrast_label >= rand_idx) # positive when equal or higher than rand_idx
698
+ if torch.sum(pos_mask) == 0: # no positive sample
699
+ continue
700
+ else:
701
+ batch_drop_mask = torch.sum(pos_mask, dim=1) > 0 # negative sample indicator
702
+
703
+ # drop higher ranks
704
+ cur_saliency_scores = saliency_scores * drop_mask / tau + ~drop_mask * -1e+3
705
+ # numerical stability
706
+ logits = cur_saliency_scores - torch.max(cur_saliency_scores, dim=1, keepdim=True)[0]
707
+ # softmax
708
+ exp_logits = torch.exp(logits)
709
+ log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6)
710
+
711
+ mean_log_prob_pos = (pos_mask * log_prob * vid_token_mask).sum(1) / (pos_mask.sum(1) + 1e-6)
712
+ loss = - mean_log_prob_pos * batch_drop_mask
713
+ loss_rank_contrastive = loss_rank_contrastive + loss.mean()
714
+ loss_rank_contrastive = loss_rank_contrastive / 12
715
+
716
+ saliency_scores = outputs["saliency_scores"] # (N, L)
717
+ pos_indices = targets["saliency_pos_labels"] # (N, #pairs)
718
+ neg_indices = targets["saliency_neg_labels"] # (N, #pairs)
719
+ num_pairs = pos_indices.shape[1] # typically 2 or 4
720
+ batch_indices = torch.arange(len(saliency_scores)).to(saliency_scores.device)
721
+ pos_scores = torch.stack(
722
+ [saliency_scores[batch_indices, pos_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
723
+ neg_scores = torch.stack(
724
+ [saliency_scores[batch_indices, neg_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
725
+ loss_saliency = torch.clamp(self.saliency_margin + neg_scores - pos_scores, min=0).sum() \
726
+ / (len(pos_scores) * num_pairs) * 2 # * 2 to keep the loss the same scale
727
+
728
+ loss_saliency = loss_saliency + loss_rank_contrastive
729
+ ########### Saliency loss to t2v attn weights ##############
730
+ """higher scores for positive clips"""
731
+ vid_token_mask = outputs["video_mask"]
732
+ saliency_scores = outputs["t2vattnvalues"].clone() # (N, L)
733
+ saliency_contrast_label = targets["saliency_all_labels"]
734
+
735
+ saliency_scores = vid_token_mask * saliency_scores + (1. - vid_token_mask) * -1e+3
736
+
737
+ tau = 0.5
738
+ loss_rank_contrastive = 0.
739
+ for rand_idx in range(1, 12):
740
+ drop_mask = ~(saliency_contrast_label > 100) # no drop
741
+ pos_mask = (saliency_contrast_label >= rand_idx) # positive when equal or higher than rand_idx
742
+ if torch.sum(pos_mask) == 0: # no positive sample
743
+ continue
744
+ else:
745
+ batch_drop_mask = torch.sum(pos_mask, dim=1) > 0 # negative sample indicator
746
+
747
+ # drop higher ranks
748
+ cur_saliency_scores = saliency_scores * drop_mask / tau + ~drop_mask * -1e+3
749
+ # numerical stability
750
+ logits = cur_saliency_scores - torch.max(cur_saliency_scores, dim=1, keepdim=True)[0]
751
+ # softmax
752
+ exp_logits = torch.exp(logits)
753
+ log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6)
754
+
755
+ mean_log_prob_pos = (pos_mask * log_prob * vid_token_mask).sum(1) / (pos_mask.sum(1) + 1e-6)
756
+ loss = - mean_log_prob_pos * batch_drop_mask
757
+ loss_rank_contrastive = loss_rank_contrastive + loss.mean()
758
+ loss_rank_contrastive_attn = loss_rank_contrastive / 12
759
+
760
+ saliency_scores = outputs["t2vattnvalues"] # (N, L)
761
+ pos_indices = targets["saliency_pos_labels"] # (N, #pairs)
762
+ neg_indices = targets["saliency_neg_labels"] # (N, #pairs)
763
+ num_pairs = pos_indices.shape[1] # typically 2 or 4
764
+ batch_indices = torch.arange(len(saliency_scores)).to(saliency_scores.device)
765
+ pos_scores = torch.stack(
766
+ [saliency_scores[batch_indices, pos_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
767
+ neg_scores = torch.stack(
768
+ [saliency_scores[batch_indices, neg_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
769
+ loss_saliency_attn = torch.clamp(self.saliency_margin + neg_scores - pos_scores, min=0).sum() \
770
+ / (len(pos_scores) * num_pairs) * 2 # * 2 to keep the loss the same scale
771
+ saliency_binary_label = torch.clamp(targets["saliency_all_labels"], 0, 1)
772
+ logits = saliency_scores.reshape(-1)
773
+ labels_x = saliency_binary_label.reshape(-1)
774
+ BCEcriterion = nn.BCELoss()
775
+ bceloss = BCEcriterion(logits, labels_x)
776
+
777
+ loss_saliency_attn = loss_rank_contrastive_attn + bceloss + loss_saliency_attn
778
+ loss_saliency += (loss_saliency_attn * self.args.lw_wattn)
779
+ return {"loss_saliency": loss_saliency}
780
+
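The rank-aware contrastive block above is repeated almost verbatim for several (scores, labels, mask) triples; the stand-alone sketch below isolates that pattern once. The always-true drop_mask is omitted, and the shapes and label range (integer saliency levels 0-4) are assumptions for illustration only.

import torch

def rank_contrastive_loss(scores, labels, mask, tau=0.5, max_level=12):
    """scores: (N, L) saliency logits; labels: (N, L) integer saliency levels (0 = not salient);
    mask: (N, L) with 1 for valid clips. Mirrors the loop repeated inside loss_saliency."""
    scores = mask * scores + (1. - mask) * -1e3           # push padded clips out of the softmax
    total = 0.
    for level in range(1, max_level):
        pos_mask = labels >= level                        # clips at or above this saliency level
        if pos_mask.sum() == 0:
            continue
        batch_has_pos = (pos_mask.sum(dim=1) > 0).float() # samples with at least one positive
        logits = scores / tau
        logits = logits - logits.max(dim=1, keepdim=True)[0]                      # numerical stability
        log_prob = logits - torch.log(torch.exp(logits).sum(1, keepdim=True) + 1e-6)
        mean_log_prob_pos = (pos_mask * log_prob * mask).sum(1) / (pos_mask.sum(1) + 1e-6)
        total = total + (-mean_log_prob_pos * batch_has_pos).mean()
    return total / max_level                              # divided by 12, as in the original loop

scores = torch.randn(2, 8)                                # 2 videos, 8 clips each
labels = torch.randint(0, 5, (2, 8)).float()              # saliency levels 0..4
mask = torch.ones(2, 8)
print(rank_contrastive_loss(scores, labels, mask))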
781
+ def loss_contrastive_moment_sentence(self, outputs, targets, indices, log=True):
782
+ if outputs["memory_moment"] is not None:
783
+ moment_token = outputs["memory_moment"]
784
+ nmmemory_moment = outputs["nmmemory_moment"]
785
+ sentence_token = outputs["sentence_txt"].squeeze(1)
786
+ sentence_dummy = outputs["sentence_dummy"].squeeze(1) # b, 1, d
787
+
788
+ moment_logits = F.normalize(moment_token, dim=1)
789
+ nmoment_logits = F.normalize(nmmemory_moment, dim=1)
790
+ sentence_logits = F.normalize(sentence_token, dim=1)
791
+ dummy_logits = F.normalize(sentence_dummy, dim=1)
792
+ # import pdb; pdb.set_trace()
793
+
794
+ similarity_matrix = torch.matmul(moment_logits, sentence_logits.T) # B B
795
+ nsimilarity_matrix = torch.matmul(nmoment_logits, sentence_logits.T) # B B
796
+ similarity_matrix = torch.cat([similarity_matrix, nsimilarity_matrix], dim=1)
797
+ labels = torch.eye(similarity_matrix.shape[0]).to(sentence_logits.device)
798
+ nlabels = torch.zeros_like(nsimilarity_matrix).to(sentence_logits.device)
799
+ labels = torch.cat([labels, nlabels], dim=1).max(dim=1)[1]
800
+
801
+ loss_ms_align = self.criterion(similarity_matrix, labels)
802
+
803
+ dummy_similarity_matrix = torch.matmul(moment_logits, dummy_logits.T)
804
+ dummy_nsimilarity_matrix = torch.matmul(nmoment_logits, dummy_logits.T)
805
+ dummy_similarity_matrix = torch.cat([dummy_similarity_matrix, dummy_nsimilarity_matrix], dim=1)
806
+ dummy_labels = (~(torch.eye(similarity_matrix.shape[0]).to(sentence_logits.device).bool())).float()
807
+ dummy_nlabels = torch.ones_like(nsimilarity_matrix).to(sentence_logits.device)
808
+ dummy_labels = torch.cat([dummy_labels, dummy_nlabels], dim=1).max(dim=1)[1]
809
+
810
+ dummy_loss_ms_align = self.criterion(dummy_similarity_matrix, dummy_labels)
811
+ loss_ms_align += dummy_loss_ms_align
812
+ video_mask = outputs['video_mask']
813
+ src_vid = outputs['src_vid'] # [bsz, L_vid, D_vid]
814
+ moment_mask_ = torch.clamp(targets["relevant_clips"], 0, 1)
815
+
816
+ momtokcls_pred = torch.matmul(moment_token.unsqueeze(1), src_vid.permute(0, 2, 1)) # bsz 1 L_vid
817
+ momtokcls_label = moment_mask_
818
+ momtokcls_logit = torch.sigmoid(momtokcls_pred)
819
+ loss_ms_align += (self.bce_criterion(momtokcls_logit.reshape(-1), momtokcls_label.reshape(-1)) * video_mask.reshape(-1)).mean()
820
+
821
+ else:
822
+ loss_ms_align = 0.
823
+ return {"loss_ms_align": loss_ms_align}
824
+ #
825
+
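The alignment term above is an InfoNCE-style loss in which each moment token must pick out its own sentence token against all sentences plus the non-moment negatives. A toy version of that matching step with random features (batch size and dimension are arbitrary):

import torch
import torch.nn.functional as F

B, D = 4, 256
moment  = F.normalize(torch.randn(B, D), dim=1)   # positive moment tokens
nmoment = F.normalize(torch.randn(B, D), dim=1)   # non-moment (negative) tokens
sent    = F.normalize(torch.randn(B, D), dim=1)   # sentence tokens

sim  = moment  @ sent.T                           # (B, B): moment-vs-sentence similarities
nsim = nmoment @ sent.T                           # (B, B): negatives appended as extra columns
logits = torch.cat([sim, nsim], dim=1)            # (B, 2B)
labels = torch.arange(B)                          # each moment matches its own sentence (diagonal)
loss_ms_align = F.cross_entropy(logits, labels)
print(loss_ms_align)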
826
+ def loss_moment2txt_sim_distill(self, outputs, targets, indices, log=True):
827
+ if outputs["moment2txt_similarity"] is not None:
828
+ moment2txt_similarity = outputs["moment2txt_similarity"] # bsz L_clip 22
829
+ moment_mask = outputs["moment_mask"].int() # bsz L_clip 1
830
+ txt_mask = outputs["txt_mask"].unsqueeze(1).repeat(1, outputs["cate_attn_weights"].size(1), 1) # bsz l_t
831
+
832
+ attn_weights = outputs["cate_attn_weights"] # bsz L_clip 22
833
+ b, L_vid, L_txt = attn_weights.size()
834
+ loss_distill = self.kld_criterion(
835
+ torch.log(attn_weights + 1e-6).reshape(b * L_vid, -1),
836
+ torch.softmax(moment2txt_similarity, dim=-1).clone().detach().reshape(b * L_vid, -1)).mean(1) * moment_mask.reshape(-1)
837
+ loss_distill = loss_distill.sum() / moment_mask.sum()
838
+
839
+ else:
840
+ loss_distill = 0.
841
+ return {"loss_distill": loss_distill}
842
+
843
+ def loss_orthogonal_dummy(self, outputs, targets, indices, log=True):
844
+ dummy_tokens = outputs["dummy_tokens"] # (n_dum, dim)
845
+ if dummy_tokens.size(1) != 1:
846
+ dummy_tokens_norm = dummy_tokens / dummy_tokens.norm(dim=2)[:, :, None]
847
+ dummy_tokens_sim = torch.matmul(dummy_tokens_norm, dummy_tokens_norm.permute(0, 2, 1).detach())
848
+ for i in range(len(dummy_tokens_sim)):
849
+ dummy_tokens_sim[i].fill_diagonal_(0)
850
+ loss_dummy_ortho = dummy_tokens_sim.abs().mean()
851
+ else:
852
+ loss_dummy_ortho=0.
853
+ global_tokens = outputs["global_rep_tokens"]
854
+
855
+ global_tokens_norm = global_tokens / global_tokens.norm(dim=1)[:, None]
856
+ global_tokens_sim = torch.matmul(global_tokens_norm, global_tokens_norm.permute(1, 0).detach())
857
+ for i in range(len(global_tokens_sim)):
858
+ global_tokens_sim.fill_diagonal_(0)
859
+ loss_dummy_ortho += global_tokens_sim.abs().mean()
860
+ return {"loss_orthogonal_dummy": loss_dummy_ortho}
861
+
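loss_orthogonal_dummy penalises pairwise cosine similarity among the dummy tokens (and among the global tokens). The same idea on a random token matrix, with illustrative sizes:

import torch

tokens = torch.randn(5, 64)                        # e.g. 5 dummy tokens of dim 64 (illustrative sizes)
tokens = tokens / tokens.norm(dim=1, keepdim=True)
sim = tokens @ tokens.T                            # (5, 5) cosine similarities
sim.fill_diagonal_(0)                              # ignore self-similarity
loss_ortho = sim.abs().mean()                      # push off-diagonal similarities toward 0
print(loss_ortho)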
862
+ def loss_contrastive_align(self, outputs, targets, indices, log=True):
863
+ """encourage higher scores between matched query span and input text"""
864
+ normalized_text_embed = outputs["proj_txt_mem"] # (bsz, #tokens, d) text tokens
865
+ normalized_img_embed = outputs["proj_queries"] # (bsz, #queries, d)
866
+ logits = torch.einsum(
867
+ "bmd,bnd->bmn", normalized_img_embed, normalized_text_embed) # (bsz, #queries, #tokens)
868
+ logits = logits.sum(2) / self.temperature # (bsz, #queries)
869
+ idx = self._get_src_permutation_idx(indices)
870
+ positive_map = torch.zeros_like(logits, dtype=torch.bool)
871
+ positive_map[idx] = True
872
+ positive_logits = logits.masked_fill(~positive_map, 0)
873
+
874
+ pos_term = positive_logits.sum(1) # (bsz, )
875
+ num_pos = positive_map.sum(1) # (bsz, )
876
+ neg_term = logits.logsumexp(1) # (bsz, )
877
+ loss_nce = - pos_term / num_pos + neg_term # (bsz, )
878
+ losses = {"loss_contrastive_align": loss_nce.mean()}
879
+ return losses
880
+
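loss_contrastive_align pools each query's similarity over all text tokens and treats Hungarian-matched queries as positives. A toy version with a fabricated positive map (in the real loss the map comes from the matcher indices):

import torch
import torch.nn.functional as F

bsz, n_q, n_tok, d = 2, 4, 6, 64
q   = F.normalize(torch.randn(bsz, n_q, d), dim=-1)   # query embeddings
txt = F.normalize(torch.randn(bsz, n_tok, d), dim=-1) # text token embeddings
temperature = 0.07

logits = torch.einsum("bmd,bnd->bmn", q, txt).sum(2) / temperature   # (bsz, n_q)
positive_map = torch.zeros_like(logits, dtype=torch.bool)
positive_map[:, 0] = True                      # pretend query 0 was matched in every sample
pos_term = logits.masked_fill(~positive_map, 0).sum(1)
num_pos  = positive_map.sum(1)
neg_term = logits.logsumexp(1)
loss_nce = (-pos_term / num_pos + neg_term).mean()
print(loss_nce)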
881
+ def loss_contrastive_align_vid_txt(self, outputs, targets, indices, log=True):
882
+ """encourage higher scores between matched query span and input text"""
883
+ normalized_text_embed = outputs["proj_txt_mem"] # (bsz, #tokens, d) text tokens
884
+ normalized_img_embed = outputs["proj_queries"] # (bsz, #queries, d)
885
+ logits = torch.einsum(
886
+ "bmd,bnd->bmn", normalized_img_embed, normalized_text_embed) # (bsz, #queries, #tokens)
887
+ logits = logits.sum(2) / self.temperature # (bsz, #queries)
888
+ idx = self._get_src_permutation_idx(indices)
889
+ positive_map = torch.zeros_like(logits, dtype=torch.bool)
890
+ positive_map[idx] = True
891
+ positive_logits = logits.masked_fill(~positive_map, 0)
892
+
893
+ pos_term = positive_logits.sum(1) # (bsz, )
894
+ num_pos = positive_map.sum(1) # (bsz, )
895
+ neg_term = logits.logsumexp(1) # (bsz, )
896
+ loss_nce = - pos_term / num_pos + neg_term # (bsz, )
897
+ losses = {"loss_contrastive_align": loss_nce.mean()}
898
+ return losses
899
+
900
+ def _get_src_permutation_idx(self, indices):
901
+ # permute predictions following indices
902
+ batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
903
+ src_idx = torch.cat([src for (src, _) in indices])
904
+ return batch_idx, src_idx # two 1D tensors of the same length
905
+
906
+ def _get_tgt_permutation_idx(self, indices):
907
+ # permute targets following indices
908
+ batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
909
+ tgt_idx = torch.cat([tgt for (_, tgt) in indices])
910
+ return batch_idx, tgt_idx
911
+
912
+ def get_loss(self, loss, outputs, targets, indices, **kwargs):
913
+ loss_map = {
914
+ "spans": self.loss_spans,
915
+ "labels": self.loss_labels,
916
+ "contrastive_align": self.loss_contrastive_align,
917
+ "saliency": self.loss_saliency,
918
+ "ms_align": self.loss_contrastive_moment_sentence,
919
+ "distill": self.loss_moment2txt_sim_distill,
920
+ "orthogonal_dummy":self.loss_orthogonal_dummy
921
+ }
922
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
923
+ return loss_map[loss](outputs, targets, indices, **kwargs)
924
+
925
+ def forward(self, outputs, targets):
926
+ """ This performs the loss computation.
927
+ Parameters:
928
+ outputs: dict of tensors, see the output specification of the model for the format
929
+ targets: list of dicts, such that len(targets) == batch_size.
930
+ The expected keys in each dict depends on the losses applied, see each loss' doc
931
+ """
932
+ outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
933
+
934
+ # Retrieve the matching between the outputs of the last layer and the targets
935
+ # list(tuples), each tuple is (pred_span_indices, tgt_span_indices)
936
+
937
+ # only for HL, do not use matcher
938
+ if self.use_matcher:
939
+ # import pdb; pdb.set_trace()
940
+ indices = self.matcher(outputs_without_aux, targets)
941
+ losses_target = self.losses
942
+ else:
943
+ indices = None
944
+ losses_target = ["saliency"]
945
+
946
+ # Compute all the requested losses
947
+ losses = {}
948
+ for loss in losses_target:
949
+ losses.update(self.get_loss(loss, outputs, targets, indices))
950
+
951
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
952
+ if 'aux_outputs' in outputs:
953
+ for i, aux_outputs in enumerate(outputs['aux_outputs']):
954
+ # indices = self.matcher(aux_outputs, targets)
955
+ if self.use_matcher:
956
+ indices = self.matcher(aux_outputs, targets)
957
+ losses_target = self.losses
958
+ else:
959
+ indices = None
960
+ losses_target = ["saliency", "ms_align", "distill", "orthogonal_dummy"]
961
+ for loss in losses_target:
962
+ if "saliency" == loss: # skip as it is only in the top layer
963
+ continue
964
+ if "ms_align" == loss:
965
+ continue
966
+ if "distill" == loss:
967
+ continue
968
+ if "orthogonal_dummy" == loss:
969
+ continue
970
+ kwargs = {}
971
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices, **kwargs)
972
+ l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
973
+ losses.update(l_dict)
974
+ return losses
975
+
976
+
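SetCriterion.forward returns an un-weighted dict of losses; in DETR-style training loops the scalar objective is assembled with criterion.weight_dict, roughly as below (the numeric values are placeholders, not real outputs):

import torch

weight_dict = {"loss_span": 10, "loss_giou": 1, "loss_label": 4, "loss_saliency": 4}
loss_dict = {"loss_span": torch.tensor(0.31),      # placeholders; normally criterion(outputs, targets)
             "loss_giou": torch.tensor(0.52),
             "loss_label": torch.tensor(0.18),
             "loss_saliency": torch.tensor(0.90),
             "class_error": torch.tensor(12.5)}    # logged only, not in weight_dict
total = sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict)
print(total)                                       # in training, this scalar is what gets backpropagated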
977
+ class MLP(nn.Module):
978
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
979
+ super().__init__()
980
+ self.num_layers = num_layers
981
+ h = [hidden_dim] * (num_layers - 1)
982
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
983
+
984
+ def forward(self, x):
985
+ for i, layer in enumerate(self.layers):
986
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
987
+ return x
988
+
989
+ class LinearLayer(nn.Module):
990
+ """linear layer configurable with layer normalization, dropout, ReLU."""
991
+
992
+ def __init__(self, input_dim, output_dim, layer_norm=True, dropout=0.1, relu=True):
993
+ super(LinearLayer, self).__init__()
994
+ self.relu = relu
995
+ self.layer_norm = layer_norm
996
+ if layer_norm:
997
+ self.LayerNorm = nn.LayerNorm(input_dim)
998
+ layers = [
999
+ nn.Dropout(dropout),
1000
+ nn.Linear(input_dim, output_dim)
1001
+ ]
1002
+ self.net = nn.Sequential(*layers)
1003
+
1004
+ def forward(self, x):
1005
+ """(N, L, D)"""
1006
+
1007
+ if self.layer_norm:
1008
+ x = self.LayerNorm(x)
1009
+ x = self.net(x)
1010
+ if self.relu:
1011
+ x = F.relu(x, inplace=True)
1012
+ return x # (N, L, D)
1013
+
1014
+ class CGDETRConfig:
1015
+ def __init__(self, dset_name='charadesSTA', eval_split_name='val', data_ratio=1.0,
1016
+ results_root='results', exp_id=None, max_es_cnt=200, eval_epoch=5,
1017
+ grad_clip=0.1, eval_untrained=False, resume_all=False, start_epoch=None,
1018
+ max_q_l=-1, max_v_l=-1, clip_length=1, max_windows=5, train_path=None,
1019
+ eval_path=None, no_norm_vfeat=False, no_norm_tfeat=False, v_feat_dirs=None,
1020
+ t_feat_dir=None, v_feat_dim=770, t_feat_dim=4096, ctx_mode='video_tef',
1021
+ position_embedding='sine', enc_layers=3, dec_layers=3, t2v_layers=2,
1022
+ sent_layers=1, moment_layers=1, dummy_layers=2, dim_feedforward=1024,
1023
+ hidden_dim=256, input_dropout=0.5, dropout=0.1, txt_drop_ratio=0,
1024
+ use_txt_pos=False, nheads=8, num_queries=10, num_dummies=45,
1025
+ total_prompts=10, num_prompts=1, pre_norm=False, n_input_proj=2,
1026
+ contrastive_hdim=64, temperature=0.07, saliency_margin=0.2, aux_loss=True,
1027
+ span_loss_type='l1', contrastive_align_loss=False, set_cost_span=10,
1028
+ set_cost_giou=1, set_cost_class=4, lw_saliency=4, lw_wattn=1.0,
1029
+ lw_ms_align=1.0, lw_distill=1.0, span_loss_coef=10, giou_loss_coef=1,
1030
+ label_loss_coef=4, eos_coef=0.1, contrastive_align_loss_coef=0.02,
1031
+ no_sort_results=False, max_before_nms=10, max_after_nms=10,
1032
+ conf_thd=0.0, nms_thd=-1):
1033
+
1034
+ self.dset_name = dset_name
1035
+ self.eval_split_name = eval_split_name
1036
+ self.data_ratio = data_ratio
1037
+ self.results_root = results_root
1038
+ self.exp_id = exp_id
1039
+ self.max_es_cnt = max_es_cnt
1040
+ self.eval_epoch = eval_epoch
1041
+ self.grad_clip = grad_clip
1042
+ self.eval_untrained = eval_untrained
1043
+ self.resume_all = resume_all
1044
+ self.start_epoch = start_epoch
1045
+ self.max_q_l = max_q_l
1046
+ self.max_v_l = max_v_l
1047
+ self.clip_length = clip_length
1048
+ self.max_windows = max_windows
1049
+ self.train_path = train_path
1050
+ self.eval_path = eval_path
1051
+ self.no_norm_vfeat = no_norm_vfeat
1052
+ self.no_norm_tfeat = no_norm_tfeat
1053
+ self.v_feat_dirs = v_feat_dirs
1054
+ self.t_feat_dir = t_feat_dir
1055
+ self.v_feat_dim = v_feat_dim
1056
+ self.t_feat_dim = t_feat_dim
1057
+ self.ctx_mode = ctx_mode
1058
+ self.position_embedding = position_embedding
1059
+ self.enc_layers = enc_layers
1060
+ self.dec_layers = dec_layers
1061
+ self.t2v_layers = t2v_layers
1062
+ self.sent_layers = sent_layers
1063
+ self.moment_layers = moment_layers
1064
+ self.dummy_layers = dummy_layers
1065
+ self.dim_feedforward = dim_feedforward
1066
+ self.hidden_dim = hidden_dim
1067
+ self.input_dropout = input_dropout
1068
+ self.dropout = dropout
1069
+ self.txt_drop_ratio = txt_drop_ratio
1070
+ self.use_txt_pos = use_txt_pos
1071
+ self.nheads = nheads
1072
+ self.num_queries = num_queries
1073
+ self.num_dummies = num_dummies
1074
+ self.total_prompts = total_prompts
1075
+ self.num_prompts = num_prompts
1076
+ self.pre_norm = pre_norm
1077
+ self.n_input_proj = n_input_proj
1078
+ self.contrastive_hdim = contrastive_hdim
1079
+ self.temperature = temperature
1080
+ self.saliency_margin = saliency_margin
1081
+ self.aux_loss = aux_loss
1082
+ self.span_loss_type = span_loss_type
1083
+ self.contrastive_align_loss = contrastive_align_loss
1084
+ self.set_cost_span = set_cost_span
1085
+ self.set_cost_giou = set_cost_giou
1086
+ self.set_cost_class = set_cost_class
1087
+ self.lw_saliency = lw_saliency
1088
+ self.lw_wattn = lw_wattn
1089
+ self.lw_ms_align = lw_ms_align
1090
+ self.lw_distill = lw_distill
1091
+ self.span_loss_coef = span_loss_coef
1092
+ self.giou_loss_coef = giou_loss_coef
1093
+ self.label_loss_coef = label_loss_coef
1094
+ self.eos_coef = eos_coef
1095
+ self.contrastive_align_loss_coef = contrastive_align_loss_coef
1096
+ self.no_sort_results = no_sort_results
1097
+ self.max_before_nms = max_before_nms
1098
+ self.max_after_nms = max_after_nms
1099
+ self.conf_thd = conf_thd
1100
+ self.nms_thd = nms_thd
1101
+
1102
+ def build_cgdetr_model():
1103
+ # device = torch.device(args.device)
1104
+ # import pdb; pdb.set_trace()
1105
+ args = CGDETRConfig()
1106
+
1107
+ transformer = build_transformer(args)
1108
+ position_embedding, txt_position_embedding = build_position_encoding(args)
1109
+
1110
+ # if args.a_feat_dir is None:
1111
+ model = CGDETR(
1112
+ transformer,
1113
+ position_embedding,
1114
+ txt_position_embedding,
1115
+ txt_dim=args.t_feat_dim,
1116
+ vid_dim=args.v_feat_dim,
1117
+ num_queries=args.num_queries,
1118
+ input_dropout=args.input_dropout,
1119
+ aux_loss=args.aux_loss,
1120
+ contrastive_align_loss=args.contrastive_align_loss,
1121
+ contrastive_hdim=args.contrastive_hdim,
1122
+ span_loss_type=args.span_loss_type,
1123
+ use_txt_pos=args.use_txt_pos,
1124
+ n_input_proj=args.n_input_proj,
1125
+ args=args
1126
+ )
1127
+ # else:
1128
+ # model = CGDETR(
1129
+ # transformer,
1130
+ # position_embedding,
1131
+ # txt_position_embedding,
1132
+ # txt_dim=args.t_feat_dim,
1133
+ # vid_dim=args.v_feat_dim,
1134
+ # aud_dim=args.a_feat_dim,
1135
+ # num_queries=args.num_queries,
1136
+ # input_dropout=args.input_dropout,
1137
+ # aux_loss=args.aux_loss,
1138
+ # contrastive_align_loss=args.contrastive_align_loss,
1139
+ # contrastive_hdim=args.contrastive_hdim,
1140
+ # span_loss_type=args.span_loss_type,
1141
+ # use_txt_pos=args.use_txt_pos,
1142
+ # n_input_proj=args.n_input_proj,
1143
+ # args=args
1144
+ # )
1145
+
1146
+ matcher = build_matcher(args)
1147
+ weight_dict = {"loss_span": args.span_loss_coef,
1148
+ "loss_giou": args.giou_loss_coef,
1149
+ "loss_label": args.label_loss_coef,
1150
+ "loss_saliency": args.lw_saliency,
1151
+ "loss_ms_align": args.lw_ms_align,
1152
+ "loss_distill": args.lw_distill,
1153
+ "loss_orthogonal_dummy":args.lw_distill}
1154
+ if args.contrastive_align_loss:
1155
+ weight_dict["loss_contrastive_align"] = args.contrastive_align_loss_coef
1156
+
1157
+ if args.aux_loss:
1158
+ aux_weight_dict = {}
1159
+ for i in range(args.dec_layers - 1):
1160
+ aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items() if k != "loss_saliency"})
1161
+ weight_dict.update(aux_weight_dict)
1162
+
1163
+ losses = ['spans', 'labels', 'saliency', 'ms_align', 'distill', 'orthogonal_dummy']
1164
+ if args.contrastive_align_loss:
1165
+ losses += ["contrastive_align"]
1166
+
1167
+ # For highlight detection datasets
1168
+ # use_matcher = not (args.dset_name in ['youtube_uni', 'tvsum'])
1169
+ use_matcher = True
1170
+
1171
+ criterion = SetCriterion(
1172
+ matcher=matcher, weight_dict=weight_dict, losses=losses,
1173
+ eos_coef=args.eos_coef, temperature=args.temperature,
1174
+ span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
1175
+ saliency_margin=args.saliency_margin, use_matcher=use_matcher, args=args
1176
+ )
1177
+ # criterion.to(device)
1178
+ return model, criterion
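As a smoke test, build_cgdetr_model() can be called without arguments because it instantiates CGDETRConfig() internally. The sketch below only builds the model and inspects it; the import path follows the package layout used elsewhere in this patch and is an assumption about how the repository is laid out on the Python path.

# Illustrative only, not part of the patched file.
from third_party.cgdetr.cg_detr.model import build_cgdetr_model

model, criterion = build_cgdetr_model()
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"CG-DETR parameters: {n_params / 1e6:.1f}M")
print(sorted(criterion.weight_dict.keys()))   # the loss terms that will be weighted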
third_party/cgdetr/cg_detr/position_encoding.py ADDED
@@ -0,0 +1,116 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Various positional encodings for the transformer.
4
+ """
5
+ import math
6
+ import torch
7
+ from torch import nn
8
+
9
+
10
+ class TrainablePositionalEncoding(nn.Module):
11
+ """Construct the embeddings from word, position and token_type embeddings.
12
+ """
13
+ def __init__(self, max_position_embeddings, hidden_size, dropout=0.1):
14
+ super(TrainablePositionalEncoding, self).__init__()
15
+ self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
16
+ self.LayerNorm = nn.LayerNorm(hidden_size)
17
+ self.dropout = nn.Dropout(dropout)
18
+
19
+ def forward(self, input_feat):
20
+ """
21
+ Args:
22
+ input_feat: (N, L, D)
23
+ """
24
+ bsz, seq_length = input_feat.shape[:2]
25
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=input_feat.device)
26
+ position_ids = position_ids.unsqueeze(0).repeat(bsz, 1) # (N, L)
27
+
28
+ position_embeddings = self.position_embeddings(position_ids)
29
+
30
+ embeddings = self.LayerNorm(input_feat + position_embeddings)
31
+ embeddings = self.dropout(embeddings)
32
+ return embeddings
33
+
34
+
35
+ class PositionEmbeddingSine(nn.Module):
36
+ """
37
+ This is a more standard version of the position embedding, very similar to the one
38
+ used by the Attention is all you need paper, generalized to work on images. (To 1D sequences)
39
+ """
40
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
41
+ super().__init__()
42
+ self.num_pos_feats = num_pos_feats
43
+ self.temperature = temperature
44
+ self.normalize = normalize
45
+ if scale is not None and normalize is False:
46
+ raise ValueError("normalize should be True if scale is passed")
47
+ if scale is None:
48
+ scale = 2 * math.pi
49
+ self.scale = scale
50
+
51
+ def forward(self, x, mask):
52
+ """
53
+ Args:
54
+ x: torch.tensor, (batch_size, L, d)
55
+ mask: torch.tensor, (batch_size, L), with 1 as valid
56
+
57
+ Returns:
58
+
59
+ """
60
+ assert mask is not None
61
+ x_embed = mask.cumsum(1, dtype=torch.float32) # (bsz, L)
62
+ if self.normalize:
63
+ eps = 1e-6
64
+ x_embed = x_embed / (x_embed[:, -1:] + eps) * self.scale
65
+
66
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
67
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
68
+
69
+ pos_x = x_embed[:, :, None] / dim_t # (bsz, L, num_pos_feats)
70
+ pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) # (bsz, L, num_pos_feats*2)
71
+ # import ipdb; ipdb.set_trace()
72
+ return pos_x # .permute(0, 2, 1) # (bsz, num_pos_feats*2, L)
73
+
74
+
75
+ class PositionEmbeddingLearned(nn.Module):
76
+ """
77
+ Absolute pos embedding, learned.
78
+ """
79
+ def __init__(self, num_pos_feats=256):
80
+ super().__init__()
81
+ self.row_embed = nn.Embedding(50, num_pos_feats)
82
+ self.col_embed = nn.Embedding(50, num_pos_feats)
83
+ self.reset_parameters()
84
+
85
+ def reset_parameters(self):
86
+ nn.init.uniform_(self.row_embed.weight)
87
+ nn.init.uniform_(self.col_embed.weight)
88
+
89
+ def forward(self, x, mask):
90
+ h, w = x.shape[-2:]
91
+ i = torch.arange(w, device=x.device)
92
+ j = torch.arange(h, device=x.device)
93
+ x_emb = self.col_embed(i)
94
+ y_emb = self.row_embed(j)
95
+ pos = torch.cat([
96
+ x_emb.unsqueeze(0).repeat(h, 1, 1),
97
+ y_emb.unsqueeze(1).repeat(1, w, 1),
98
+ ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
99
+ return pos
100
+
101
+
102
+ def build_position_encoding(args):
103
+ N_steps = args.hidden_dim
104
+ if args.position_embedding in ('v2', 'sine'):
105
+ # TODO find a better way of exposing other arguments
106
+ position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
107
+ # elif args.position_embedding in ('v3', 'learned'):
108
+ # position_embedding = PositionEmbeddingLearned(N_steps)
109
+ else:
110
+ raise ValueError(f"not supported {args.position_embedding}")
111
+ if args.max_q_l == -1:
112
+ args.max_q_l = 100
113
+ txt_pos_embed = TrainablePositionalEncoding(
114
+ max_position_embeddings=args.max_q_l,
115
+ hidden_size=args.hidden_dim, dropout=args.input_dropout)
116
+ return position_embedding, txt_pos_embed
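Because PositionEmbeddingSine returns num_pos_feats channels per position, build_position_encoding passes hidden_dim directly so the encoding matches the transformer width. A quick shape check, assuming the class above is in scope and using illustrative sizes:

import torch

pos_enc = PositionEmbeddingSine(num_pos_feats=256, normalize=True)
x = torch.randn(2, 75, 256)          # (batch, L_vid, d) video features
mask = torch.ones(2, 75)             # 1 = valid clip
pos = pos_enc(x, mask)
print(pos.shape)                     # torch.Size([2, 75, 256])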
third_party/cgdetr/cg_detr/postprocessing_cg_detr.py ADDED
@@ -0,0 +1,95 @@
1
+ import pprint
2
+ import numpy as np
3
+ import torch
4
+ from third_party.cgdetr.utils.basic_utils import load_jsonl
5
+ from third_party.cgdetr.standalone_eval.eval import eval_submission
6
+ from tqdm import tqdm
7
+
8
+
9
+ class PostProcessorDETR:
10
+ def __init__(self, clip_length=2, min_ts_val=0, max_ts_val=150,
11
+ min_w_l=2, max_w_l=70, move_window_method="center",
12
+ process_func_names=("clip_window_l", "clip_ts", "round_multiple")):
13
+ self.clip_length = clip_length
14
+ self.min_ts_val = min_ts_val
15
+ self.max_ts_val = max_ts_val
16
+ self.min_w_l = min_w_l
17
+ self.max_w_l = max_w_l
18
+ self.move_window_method = move_window_method
19
+ self.process_func_names = process_func_names
20
+ self.name2func = dict(
21
+ clip_ts=self.clip_min_max_timestamps,
22
+ round_multiple=self.round_to_multiple_clip_lengths,
23
+ clip_window_l=self.clip_window_lengths
24
+ )
25
+
26
+ def __call__(self, lines):
27
+ processed_lines = []
28
+ for line in tqdm(lines, desc=f"convert to multiples of clip_length={self.clip_length}"):
29
+ windows_and_scores = torch.tensor(line["pred_relevant_windows"])
30
+ windows = windows_and_scores[:, :2]
31
+ for func_name in self.process_func_names:
32
+ windows = self.name2func[func_name](windows)
33
+ line["pred_relevant_windows"] = torch.cat(
34
+ [windows, windows_and_scores[:, 2:3]], dim=1).tolist()
35
+ line["pred_relevant_windows"] = [e[:2] + [float(f"{e[2]:.4f}")] for e in line["pred_relevant_windows"]]
36
+ processed_lines.append(line)
37
+ return processed_lines
38
+
39
+ def clip_min_max_timestamps(self, windows):
40
+ """
41
+ windows: (#windows, 2) torch.Tensor
42
+ ensure timestamps for all windows is within [min_val, max_val], clip is out of boundaries.
43
+ """
44
+ return torch.clamp(windows, min=self.min_ts_val, max=self.max_ts_val)
45
+
46
+ def round_to_multiple_clip_lengths(self, windows):
47
+ """
48
+ windows: (#windows, 2) torch.Tensor
49
+ ensure the final window timestamps are multiples of `clip_length`
50
+ """
51
+ return torch.round(windows / self.clip_length) * self.clip_length
52
+
53
+ def clip_window_lengths(self, windows):
54
+ """
55
+ windows: (#windows, 2) np.ndarray
56
+ ensure the final window duration are within [self.min_w_l, self.max_w_l]
57
+ """
58
+ window_lengths = windows[:, 1] - windows[:, 0]
59
+ small_rows = window_lengths < self.min_w_l
60
+ if torch.sum(small_rows) > 0:
61
+ windows = self.move_windows(
62
+ windows, small_rows, self.min_w_l, move_method=self.move_window_method)
63
+ large_rows = window_lengths > self.max_w_l
64
+ if torch.sum(large_rows) > 0:
65
+ windows = self.move_windows(
66
+ windows, large_rows, self.max_w_l, move_method=self.move_window_method)
67
+ return windows
68
+
69
+ @classmethod
70
+ def move_windows(cls, windows, row_selector, new_length, move_method="left"):
71
+ """
72
+ Args:
73
+ windows:
74
+ row_selector:
75
+ new_length:
76
+ move_method: str,
77
+ left: keep left unchanged
78
+ center: keep center unchanged
79
+ right: keep right unchanged
80
+
81
+ Returns:
82
+
83
+ """
84
+ # import ipdb;
85
+ # ipdb.set_trace()
86
+ if move_method == "left":
87
+ windows[row_selector, 1] = windows[row_selector, 0] + new_length
88
+ elif move_method == "right":
89
+ windows[row_selector, 0] = windows[row_selector, 1] - new_length
90
+ elif move_method == "center":
91
+ center = (windows[row_selector, 1] + windows[row_selector, 0]) / 2.
92
+ windows[row_selector, 0] = center - new_length / 2.
93
+ windows[row_selector, 1] = center + new_length / 2.
94
+ return windows
95
+
third_party/cgdetr/cg_detr/scripts/charades_sta/inference.sh ADDED
@@ -0,0 +1,8 @@
1
+ ckpt_path=$1
2
+ eval_split_name=$2
3
+ eval_path=data/highlight_${eval_split_name}_release.jsonl
4
+ PYTHONPATH=$PYTHONPATH:. python cg_detr/inference.py \
5
+ --resume ${ckpt_path} \
6
+ --eval_split_name ${eval_split_name} \
7
+ --eval_path ${eval_path} \
8
+ ${@:3}
third_party/cgdetr/cg_detr/scripts/charades_sta/train.sh ADDED
@@ -0,0 +1,95 @@
1
+ dset_name=charadesSTA
2
+ ctx_mode=video_tef
3
+ v_feat_types=intern
4
+ t_feat_type=intern
5
+ results_root=results_charades
6
+ exp_id=exp
7
+
8
+ ######## data paths
9
+ train_path=data/charades_sta/charades_sta_train_tvr_format.jsonl
10
+ eval_path=data/charades_sta/charades_sta_test_tvr_format.jsonl
11
+ eval_split_name=val
12
+
13
+ ######## setup video+text features
14
+ feat_root=/mnt/petrelfs/lizhilin/CGDETR-main/features/charades
15
+
16
+ # video features
17
+ v_feat_dim=0
18
+ v_feat_dirs=()
19
+ if [[ ${v_feat_types} == *"slowfast"* ]]; then
20
+ v_feat_dirs+=(${feat_root}/slowfast_features)
21
+ (( v_feat_dim += 2304 )) # double brackets for arithmetic op, no need to use ${v_feat_dim}
22
+ fi
23
+ if [[ ${v_feat_types} == *"clip"* ]]; then
24
+ v_feat_dirs+=(${feat_root}/clip_features)
25
+ (( v_feat_dim += 512 ))
26
+ fi
27
+ if [[ ${v_feat_types} == *"intern"* ]]; then
28
+ v_feat_dirs+=(${feat_root}/charade_sta_internvideo2_videoclip_6b_w1s)
29
+ (( v_feat_dim += 768 ))
30
+ fi
31
+
32
+ # text features
33
+ if [[ ${t_feat_type} == "clip" ]]; then
34
+ t_feat_dir=${feat_root}/clip_text_features/
35
+ t_feat_dim=512
36
+ fi
37
+ if [[ ${t_feat_type} == *"intern"* ]]; then
38
+ t_feat_dir=(${feat_root}/charade_sta_internvideo2_llama_text_feature)
39
+ t_feat_dim=4096
40
+ fi
41
+
42
+ #### training
43
+ bsz=32
44
+ eval_bsz=32
45
+ num_dummies=45
46
+ num_prompts=2
47
+ total_prompts=10
48
+ lr_drop=400
49
+ enc_layers=3
50
+ dec_layers=3
51
+ t2v_layers=2
52
+ dummy_layers=2
53
+ moment_layers=1
54
+ sent_layers=1
55
+
56
+ PYTHONPATH=$PYTHONPATH:. \
57
+ srun -p video5 \
58
+ --preempt \
59
+ --job-name=${JOB_NAME} \
60
+ --ntasks=1 \
61
+ --gres=gpu:1 \
62
+ --ntasks-per-node=1 \
63
+ --cpus-per-task=8 \
64
+ --kill-on-bad-exit=1 \
65
+ python cg_detr/train.py \
66
+ --dset_name ${dset_name} \
67
+ --ctx_mode ${ctx_mode} \
68
+ --train_path ${train_path} \
69
+ --eval_path ${eval_path} \
70
+ --eval_split_name ${eval_split_name} \
71
+ --v_feat_dirs ${v_feat_dirs[@]} \
72
+ --v_feat_dim ${v_feat_dim} \
73
+ --t_feat_dir ${t_feat_dir} \
74
+ --t_feat_dim ${t_feat_dim} \
75
+ --bsz ${bsz} \
76
+ --results_root ${results_root} \
77
+ --exp_id ${exp_id} \
78
+ --max_v_l -1 \
79
+ --clip_length 1 \
80
+ --lr 0.0002 \
81
+ --lr_drop ${lr_drop} \
82
+ --n_epoch 200 \
83
+ --contrastive_align_loss_coef 0.002 \
84
+ --lw_saliency 4 \
85
+ --enc_layers ${enc_layers} \
86
+ --dec_layers ${dec_layers} \
87
+ --t2v_layers ${t2v_layers} \
88
+ --moment_layers ${moment_layers} \
89
+ --dummy_layers ${dummy_layers} \
90
+ --sent_layers ${sent_layers} \
91
+ --eval_bsz ${eval_bsz} \
92
+ --num_dummies ${num_dummies} \
93
+ --num_prompts ${num_prompts} \
94
+ --total_prompts ${total_prompts} \
95
+ ${@:1}
third_party/cgdetr/cg_detr/scripts/inference.sh ADDED
@@ -0,0 +1,11 @@
1
+ ckpt_path=$1
2
+ eval_split_name=$2
3
+ eval_path=data/highlight_${eval_split_name}_release.jsonl
4
+ echo ${ckpt_path}
5
+ echo ${eval_split_name}
6
+ echo ${eval_path}
7
+ PYTHONPATH=$PYTHONPATH:. python cg_detr/inference.py \
8
+ --resume ${ckpt_path} \
9
+ --eval_split_name ${eval_split_name} \
10
+ --eval_path ${eval_path} \
11
+ ${@:3}
third_party/cgdetr/cg_detr/scripts/train.sh ADDED
@@ -0,0 +1,76 @@
1
+ dset_name=hl
2
+ ctx_mode=video_tef
3
+ v_feat_types=intern
4
+ t_feat_type=intern
5
+ results_root=results_qvhighlights
6
+ exp_id=exp
7
+
8
+ ######## data paths
9
+ train_path=data/highlight_train_release.jsonl
10
+ eval_path=data/highlight_val_release.jsonl
11
+ eval_split_name=val
12
+
13
+ ######## setup video+text features
14
+ feat_root=../features/qvhighlight
15
+
16
+ # video features
17
+ v_feat_dim=0
18
+ v_feat_dirs=()
19
+ if [[ ${v_feat_types} == *"slowfast"* ]]; then
20
+ v_feat_dirs+=(${feat_root}/slowfast_features)
21
+ (( v_feat_dim += 2304 )) # double brackets for arithmetic op, no need to use ${v_feat_dim}
22
+ fi
23
+ if [[ ${v_feat_types} == *"clip"* ]]; then
24
+ v_feat_dirs+=(${feat_root}/clip_features)
25
+ (( v_feat_dim += 512 ))
26
+ fi
27
+ if [[ ${v_feat_types} == *"intern"* ]]; then
28
+ v_feat_dirs+=(${feat_root}/qvhighlight_internvideo2_videoclip_6b_w2s)
29
+ (( v_feat_dim += 768 ))
30
+ fi
31
+
32
+ # text features
33
+ if [[ ${t_feat_type} == "clip" ]]; then
34
+ t_feat_dir=${feat_root}/clip_text_features/
35
+ t_feat_dim=512
36
+ fi
37
+ if [[ ${t_feat_type} == *"intern"* ]]; then
38
+ t_feat_dir=(${feat_root}/qvhighlight_internvideo2_llama_text_feature)
39
+ t_feat_dim=4096
40
+ fi
41
+
42
+
43
+ #### training
44
+ bsz=32
45
+ enc_layers=3
46
+ dec_layers=3
47
+ t2v_layers=2
48
+ moment_layers=1
49
+ dummy_layers=2
50
+ sent_layers=1
51
+ max_v_l=75
52
+ max_q_l=32
53
+
54
+ PYTHONPATH=$PYTHONPATH:. python cg_detr/train.py \
55
+ --dset_name ${dset_name} \
56
+ --ctx_mode ${ctx_mode} \
57
+ --train_path ${train_path} \
58
+ --eval_path ${eval_path} \
59
+ --eval_split_name ${eval_split_name} \
60
+ --v_feat_dirs ${v_feat_dirs[@]} \
61
+ --v_feat_dim ${v_feat_dim} \
62
+ --t_feat_dir ${t_feat_dir} \
63
+ --t_feat_dim ${t_feat_dim} \
64
+ --bsz ${bsz} \
65
+ --lr 0.0002 \
66
+ --results_root ${results_root} \
67
+ --exp_id ${exp_id} \
68
+ --enc_layers ${enc_layers} \
69
+ --dec_layers ${dec_layers} \
70
+ --t2v_layers ${t2v_layers} \
71
+ --moment_layers ${moment_layers} \
72
+ --dummy_layers ${dummy_layers} \
73
+ --sent_layers ${sent_layers} \
74
+ --max_v_l ${max_v_l} \
75
+ --max_q_l ${max_q_l} \
76
+ ${@:1}
third_party/cgdetr/cg_detr/span_utils.py ADDED
@@ -0,0 +1,127 @@
1
+ import torch
2
+
3
+
4
+ def span_xx_to_cxw(xx_spans):
5
+ """
6
+ Args:
7
+ xx_spans: tensor, (#windows, 2) or (..., 2), each row is a window of format (st, ed)
8
+
9
+ Returns:
10
+ cxw_spans: tensor, (#windows, 2), each row is a window of format (center=(st+ed)/2, width=(ed-st))
11
+ >>> spans = torch.Tensor([[0, 1], [0.2, 0.4]])
12
+ >>> span_xx_to_cxw(spans)
13
+ tensor([[0.5000, 1.0000],
14
+ [0.3000, 0.2000]])
15
+ >>> spans = torch.Tensor([[[0, 1], [0.2, 0.4]]])
16
+ >>> span_xx_to_cxw(spans)
17
+ tensor([[[0.5000, 1.0000],
18
+ [0.3000, 0.2000]]])
19
+ """
20
+ center = xx_spans.sum(-1) * 0.5
21
+ width = xx_spans[..., 1] - xx_spans[..., 0]
22
+ return torch.stack([center, width], dim=-1)
23
+
24
+
25
+ def span_cxw_to_xx(cxw_spans):
26
+ """
27
+ Args:
28
+ cxw_spans: tensor, (#windows, 2) or (..., 2), the last dim is a row denoting a window of format (center, width)
29
+
30
+ >>> spans = torch.Tensor([[0.5000, 1.0000], [0.3000, 0.2000]])
31
+ >>> span_cxw_to_xx(spans)
32
+ tensor([[0.0000, 1.0000],
33
+ [0.2000, 0.4000]])
34
+ >>> spans = torch.Tensor([[[0.5000, 1.0000], [0.3000, 0.2000]]])
35
+ >>> span_cxw_to_xx(spans)
36
+ tensor([[[0.0000, 1.0000],
37
+ [0.2000, 0.4000]]])
38
+ """
39
+ x1 = cxw_spans[..., 0] - 0.5 * cxw_spans[..., 1]
40
+ x2 = cxw_spans[..., 0] + 0.5 * cxw_spans[..., 1]
41
+ return torch.stack([x1, x2], dim=-1)
42
+
43
+
44
+ def temporal_iou(spans1, spans2):
45
+ """
46
+ Args:
47
+ spans1: (N, 2) torch.Tensor, each row defines a span [st, ed]
48
+ spans2: (M, 2) torch.Tensor, ...
49
+
50
+ Returns:
51
+ iou: (N, M) torch.Tensor
52
+ union: (N, M) torch.Tensor
53
+ >>> test_spans1 = torch.Tensor([[0, 0.2], [0.5, 1.0]])
54
+ >>> test_spans2 = torch.Tensor([[0, 0.3], [0., 1.0]])
55
+ >>> temporal_iou(test_spans1, test_spans2)
56
+ (tensor([[0.6667, 0.2000],
57
+ [0.0000, 0.5000]]),
58
+ tensor([[0.3000, 1.0000],
59
+ [0.8000, 1.0000]]))
60
+ """
61
+ areas1 = spans1[:, 1] - spans1[:, 0] # (N, )
62
+ areas2 = spans2[:, 1] - spans2[:, 0] # (M, )
63
+
64
+ left = torch.max(spans1[:, None, 0], spans2[:, 0]) # (N, M)
65
+ right = torch.min(spans1[:, None, 1], spans2[:, 1]) # (N, M)
66
+
67
+ inter = (right - left).clamp(min=0) # (N, M)
68
+ union = areas1[:, None] + areas2 - inter # (N, M)
69
+
70
+ iou = inter / union
71
+ return iou, union
72
+
73
+
74
+ def temporal_intersection_over_pred(gt_spans, pred_spans):
75
+ """ intersection over the second input spans
76
+ Args:
77
+ gt_spans: (N, 2),
78
+ pred_spans: (M, 2)
79
+
80
+ Returns:
81
+
82
+ """
83
+ left = torch.max(gt_spans[:, None, 0], pred_spans[:, 0])
84
+ right = torch.min(gt_spans[:, None, 1], pred_spans[:, 1])
85
+
86
+ inter = (right - left).clamp(min=0) # (N, M)
87
+ inter_over_pred = inter / (pred_spans[:, 1] - pred_spans[:, 0])
88
+ return inter_over_pred
89
+
90
+
91
+ def generalized_temporal_iou(spans1, spans2):
92
+ """
93
+ Generalized IoU from https://giou.stanford.edu/
94
+ Also reference to DETR implementation of generalized_box_iou
95
+ https://github.com/facebookresearch/detr/blob/master/util/box_ops.py#L40
96
+
97
+ Args:
98
+ spans1: (N, 2) torch.Tensor, each row defines a span in xx format [st, ed]
99
+ spans2: (M, 2) torch.Tensor, ...
100
+
101
+ Returns:
102
+ giou: (N, M) torch.Tensor
103
+
104
+ >>> test_spans1 = torch.Tensor([[0, 0.2], [0.5, 1.0]])
105
+ >>> test_spans2 = torch.Tensor([[0, 0.3], [0., 1.0]])
106
+ >>> generalized_temporal_iou(test_spans1, test_spans2)
107
+ tensor([[ 0.6667, 0.2000],
108
+ [-0.2000, 0.5000]])
109
+ """
110
+ spans1 = spans1.float()
111
+ spans2 = spans2.float()
112
+
113
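+ # debug guard: if every span is degenerate (end < start), dump the tensors for offline inspection and nudge the end points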
+ if (spans1[:, 1] < spans1[:, 0]).all():
114
+ torch.save({'spans1': spans1.cpu(), 'spans2': spans2.cpu()}, 'test_spans.pt')
115
+ spans1[:, 1] += 0.0001
116
+ print(spans1)
117
+ assert (spans1[:, 1] >= spans1[:, 0]).all()
118
+ assert (spans2[:, 1] >= spans2[:, 0]).all()
119
+ iou, union = temporal_iou(spans1, spans2)
120
+
121
+ left = torch.min(spans1[:, None, 0], spans2[:, 0]) # (N, M)
122
+ right = torch.max(spans1[:, None, 1], spans2[:, 1]) # (N, M)
123
+ enclosing_area = (right - left).clamp(min=0) # (N, M)
124
+
125
+ return iou - (enclosing_area - union) / enclosing_area
126
+
127
+
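A minimal usage sketch of these span helpers (illustrative only; the import path assumes the repository layout shown in this diff):

import torch
from third_party.cgdetr.cg_detr.span_utils import (
    span_xx_to_cxw, span_cxw_to_xx, temporal_iou, generalized_temporal_iou)

spans = torch.tensor([[0.0, 0.2], [0.5, 1.0]])       # (start, end), normalized by video duration
cxw = span_xx_to_cxw(spans)                          # tensor([[0.10, 0.20], [0.75, 0.50]])
assert torch.allclose(span_cxw_to_xx(cxw), spans)    # the two conversions are inverses

gt = torch.tensor([[0.0, 0.3], [0.0, 1.0]])
iou, union = temporal_iou(spans, gt)                 # both (2, 2)
giou = generalized_temporal_iou(spans, gt)           # in [-1, 1]; negative when spans are disjoint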
third_party/cgdetr/cg_detr/start_end_dataset.py ADDED
@@ -0,0 +1,383 @@
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+ import random
6
+ import logging
7
+ from os.path import join, exists
8
+ from third_party.cgdetr.utils.basic_utils import load_jsonl, l2_normalize_np_array
9
+ from third_party.cgdetr.utils.tensor_utils import pad_sequences_1d
10
+ from third_party.cgdetr.cg_detr.span_utils import span_xx_to_cxw
11
+ # from torchtext import vocab
12
+ import torch.nn as nn
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class StartEndDataset(Dataset):
18
+ Q_FEAT_TYPES = ["pooler_output", "last_hidden_state"]
19
+ """One line of the data loaded from data_path:
20
+ {
21
+ "qid": 7803,
22
+ "query": "Man in gray top walks from outside to inside.",
23
+ "duration": 150,
24
+ "vid": "RoripwjYFp8_360.0_510.0",
25
+ "relevant_clip_ids": [13, 14, 15, 16, 17],
26
+ "relevant_windows": [[26, 36]]
27
+ }
28
+ """
29
+
30
+ def __init__(self, dset_name, data_path, v_feat_dirs, q_feat_dir,
31
+ q_feat_type="last_hidden_state",
32
+ max_q_l=32, max_v_l=75, data_ratio=1.0, ctx_mode="video",
33
+ normalize_v=True, normalize_t=True, load_labels=True,
34
+ clip_len=2, max_windows=5, span_loss_type="l1", txt_drop_ratio=0,
35
+ dset_domain=None):
36
+ self.dset_name = dset_name
37
+ self.data_path = data_path
38
+ self.data_ratio = data_ratio
39
+ self.v_feat_dirs = v_feat_dirs \
40
+ if isinstance(v_feat_dirs, list) else [v_feat_dirs]
41
+ self.q_feat_dir = q_feat_dir
42
+ self.q_feat_type = q_feat_type
43
+ if max_v_l == -1:
44
+ max_v_l = 100000000
45
+ if max_q_l == -1:
46
+ max_q_l = 100
47
+ self.max_q_l = max_q_l
48
+ self.max_v_l = max_v_l
49
+ self.ctx_mode = ctx_mode
50
+ self.use_tef = "tef" in ctx_mode
51
+ self.use_video = "video" in ctx_mode
52
+ self.normalize_t = normalize_t
53
+ self.normalize_v = normalize_v
54
+ self.load_labels = load_labels
55
+ self.clip_len = clip_len
56
+ self.max_windows = max_windows # maximum number of windows to use as labels
57
+ self.span_loss_type = span_loss_type
58
+ self.txt_drop_ratio = txt_drop_ratio
59
+ if "val" in data_path or "test" in data_path:
60
+ assert txt_drop_ratio == 0
61
+
62
+ if self.dset_name == 'hl':
63
+ self.max_q_l = 32
64
+ self.max_v_l = 75
65
+ self.clip_len = 2
66
+
67
+ # checks
68
+ assert q_feat_type in self.Q_FEAT_TYPES
69
+
70
+ # data
71
+ self.data = self.load_data()
72
+
73
+ self.use_glove = False
74
+ self.use_glove = 'vgg' in self.v_feat_dirs[0]
75
+
76
+ # if self.dset_name == 'charadesSTA' and self.use_glove:
77
+ # self.vocab = vocab.pretrained_aliases['glove.6B.300d']()
78
+ # self.vocab.itos.extend(['<unk>'])
79
+ # self.vocab.stoi['<unk>'] = self.vocab.vectors.shape[0]
80
+ # self.vocab.vectors = torch.cat(
81
+ # (self.vocab.vectors, torch.zeros(1, self.vocab.dim)), dim=0)
82
+ # self.embedding = nn.Embedding.from_pretrained(self.vocab.vectors)
83
+
84
+
85
+ def load_data(self):
86
+ datalist = load_jsonl(self.data_path)
87
+ if self.data_ratio != 1:
88
+ n_examples = int(len(datalist) * self.data_ratio)
89
+ datalist = datalist[:n_examples]
90
+ logger.info("Using {}% of the data: {} examples"
91
+ .format(self.data_ratio * 100, n_examples))
92
+ return datalist
93
+
94
+ def __len__(self):
95
+ return len(self.data)
96
+
97
+
98
+ def __getitem__(self, index):
99
+ meta = self.data[index]
100
+
101
+ model_inputs = dict()
102
+
103
+ if self.use_glove: # False
104
+ model_inputs["query_feat"] = self.get_query(meta["query"])
105
+ else:
106
+ model_inputs["query_feat"] = self._get_query_feat_by_qid(meta["qid"]) # (Dq, ) or (Lq, Dq) # [16, 4096]
107
+
108
+
109
+ if self.use_video : # True
110
+ model_inputs["video_feat"] = self._get_video_feat_by_vid(meta["vid"]) # (Lv, Dv)
111
+ ctx_l = len(model_inputs["video_feat"])
112
+ else:
113
+ ctx_l = self.max_v_l
114
+
115
+
116
+ if self.use_tef:
117
+ tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
118
+ tef_ed = tef_st + 1.0 / ctx_l
119
+ tef = torch.stack([tef_st, tef_ed], dim=1) # (Lv, 2)
120
+ if self.use_video :
121
+ model_inputs["video_feat"] = torch.cat(
122
+ [model_inputs["video_feat"], tef], dim=1) # (Lv, Dv+2)
123
+ else:
124
+ model_inputs["video_feat"] = tef
125
+
126
+
127
+
128
+ if "relevant_windows" in meta: ## For Qvhighlights test set
129
+ model_inputs["span_labels"] = self.get_span_labels(meta["relevant_windows"], ctx_l) # (#windows, 2)
130
+ if self.dset_name in ['charadesSTA', 'tacos', 'activitynet']: ## charades, tacos, nlq
131
+ model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"], model_inputs["saliency_all_labels"] = \
132
+ self.get_saliency_labels_sub_as_query(meta["relevant_windows"][0], meta["duration"], ctx_l) # only one gt
133
+ elif "subs_train" not in self.data_path:
134
+ model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"], model_inputs["saliency_all_labels"] = \
135
+ self.get_saliency_labels_all(meta["relevant_clip_ids"], meta["saliency_scores"], ctx_l)
136
+ else:
137
+ model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"], model_inputs["saliency_all_labels"] = \
138
+ self.get_saliency_labels_sub_as_query(meta["relevant_windows"][0], meta["duration"], ctx_l) # only one gt
139
+
140
+ if 'qvhighlight' in self.data_path or 'qvhl' in self.data_path:
141
+ model_inputs["relevant_clip_ids"] = meta["relevant_clip_ids"]
142
+ model_inputs["vid"] = meta["vid"]
143
+ model_inputs["qid"] = meta["qid"]
144
+ return dict(meta=meta, model_inputs=model_inputs)
145
+
146
+ # def get_query(self, query):
147
+ # word_inds = torch.LongTensor(
148
+ # [self.vocab.stoi.get(w.lower(), 400000) for w in query.split()])
149
+ # return self.embedding(word_inds)
150
+ def get_query(self, query):
151
+ raise NotImplementedError("GloVe-based query embedding is not supported in this codebase")
153
+
154
+ def get_saliency_labels_sub_as_query(self, gt_window, duration, ctx_l, max_n=2):
155
+ clip_len = duration / ctx_l
156
+ gt_st = int(gt_window[0] / clip_len)
157
+ gt_ed = max(0, min(int(gt_window[1] / clip_len), ctx_l) - 1)
158
+ if gt_st > gt_ed:
159
+ gt_st = gt_ed
160
+
161
+ if gt_st != gt_ed:
162
+ pos_clip_indices = random.sample(range(gt_st, gt_ed + 1), k=max_n) # randomly pick max_n clip indices inside the GT window
163
+ else:
164
+ if self.dset_name == 'nlq':
165
+ pos_clip_indices = [gt_st] * 2
166
+ else:
167
+ pos_clip_indices = [gt_st, gt_st]
168
+
169
+ neg_pool = list(range(0, gt_st)) + list(range(gt_ed+1, ctx_l)) # clip indices outside the GT window
170
+ try:
171
+ neg_clip_indices = random.sample(neg_pool, k=max_n) # randomly pick max_n clip indices outside the GT window
172
+ except ValueError: # neg_pool has fewer than max_n clips
173
+ neg_clip_indices = pos_clip_indices
174
+
175
+ # For charades_sta
176
+ score_array = np.zeros(ctx_l)
177
+ score_array[gt_st:gt_ed + 1] = 1
178
+
179
+ return pos_clip_indices, neg_clip_indices, score_array
180
+
181
+
182
+ def get_saliency_labels(self, rel_clip_ids, scores, ctx_l, max_n=1, add_easy_negative=True):
183
+ """Sum the scores from the three annotations, then take the two clips with the
184
+ maximum scores as positive, and two with the minimum scores as negative.
185
+ Args:
186
+ rel_clip_ids: list(int), list of relevant clip ids
187
+ scores: list([anno1_score, anno2_score, anno3_score]),
188
+ ctx_l: int
189
+ max_n: int, #clips to use as positive and negative, for easy and hard negative, respectively.
190
+ add_easy_negative: bool, if True, sample easy negatives outside the relevant_clip_ids.
191
+ """
192
+ # indices inside rel_clip_ids
193
+ scores = np.array(scores) # (#rel_clips, 3)
194
+ agg_scores = np.sum(scores, 1) # (#rel_clips, )
195
+ sort_indices = np.argsort(agg_scores) # increasing
196
+
197
+ # indices in the whole video
198
+ # the min(_, ctx_l-1) here is incorrect, but should not cause
199
+ # much troubles since this should be rarely used.
200
+ hard_pos_clip_indices = [min(rel_clip_ids[idx], ctx_l-1) for idx in sort_indices[-max_n:]]
201
+ hard_neg_clip_indices = [min(rel_clip_ids[idx], ctx_l-1) for idx in sort_indices[:max_n]]
202
+ easy_pos_clip_indices = []
203
+ easy_neg_clip_indices = []
204
+ if add_easy_negative:
205
+ easy_neg_pool = list(set(range(ctx_l)) - set(rel_clip_ids))
206
+ if len(easy_neg_pool) >= max_n:
207
+ easy_pos_clip_indices = random.sample(rel_clip_ids, k=max_n)
208
+ easy_neg_clip_indices = random.sample(easy_neg_pool, k=max_n)
209
+ else: # copy the hard ones
210
+ easy_pos_clip_indices = hard_pos_clip_indices
211
+ easy_neg_clip_indices = hard_neg_clip_indices
212
+
213
+ pos_clip_indices = hard_pos_clip_indices + easy_pos_clip_indices
214
+ neg_clip_indices = hard_neg_clip_indices + easy_neg_clip_indices
215
+ return pos_clip_indices, neg_clip_indices
216
+
217
+ def get_saliency_labels_all(self, rel_clip_ids, scores, ctx_l, max_n=1, add_easy_negative=True):
218
+ """Sum the scores from the three annotations, then take the two clips with the
219
+ maximum scores as positive, and two with the minimum scores as negative.
220
+ Args:
221
+ rel_clip_ids: list(int), list of relevant clip ids
222
+ scores: list([anno1_score, anno2_score, anno3_score]),
223
+ ctx_l: int
224
+ max_n: int, #clips to use as positive and negative, for easy and hard negative, respectively.
225
+ add_easy_negative: bool, if True, sample easy negatives outside the relevant_clip_ids.
226
+ """
227
+ # indices inside rel_clip_ids
228
+ scores = np.array(scores) # (#rel_clips, 3)
229
+ agg_scores = np.sum(scores, 1) # (#rel_clips, )
230
+ sort_indices = np.argsort(agg_scores) # increasing
231
+
232
+ # score_array = [min(agg_scores[idx], ctx_l-1) for idx in range(ctx_l)]
233
+ score_array = np.zeros(ctx_l)
234
+ max_len=ctx_l
235
+ for idx in range(len(rel_clip_ids)):
236
+ if rel_clip_ids[idx] >= ctx_l:
237
+ max_len=max(max_len,rel_clip_ids[idx])
238
+ # score_array_new = np.zeros(ctx_l + 1)
239
+ score_array_new = np.zeros(max_len+1)
240
+ # score_array_new[:ctx_l] = score_array
241
+ score_array_new[:len(score_array)] = score_array
242
+ score_array = score_array_new
243
+ score_array[rel_clip_ids[idx]] = agg_scores[idx]
244
+
245
+ # indices in the whole video
246
+ # the min(_, ctx_l-1) here is incorrect, but should not cause
247
+ # much troubles since this should be rarely used.
248
+ hard_pos_clip_indices = [min(rel_clip_ids[idx], ctx_l-1) for idx in sort_indices[-max_n:]]
249
+ hard_neg_clip_indices = [min(rel_clip_ids[idx], ctx_l-1) for idx in sort_indices[:max_n]]
250
+ easy_pos_clip_indices = []
251
+ easy_neg_clip_indices = []
252
+ if add_easy_negative:
253
+ easy_neg_pool = list(set(range(ctx_l)) - set(rel_clip_ids))
254
+ if len(easy_neg_pool) >= max_n:
255
+ easy_pos_clip_indices = random.sample(rel_clip_ids, k=max_n)
256
+ easy_neg_clip_indices = random.sample(easy_neg_pool, k=max_n)
257
+ else: # copy the hard ones
258
+ easy_pos_clip_indices = hard_pos_clip_indices
259
+ easy_neg_clip_indices = hard_neg_clip_indices
260
+
261
+ pos_clip_indices = hard_pos_clip_indices + easy_pos_clip_indices
262
+ neg_clip_indices = hard_neg_clip_indices + easy_neg_clip_indices
263
+ return pos_clip_indices, neg_clip_indices, score_array
264
+
265
+ def get_span_labels(self, windows, ctx_l):
266
+ """
267
+ windows: list([st, ed]) in seconds. E.g. [[26, 36]], corresponding st_ed clip_indices [[13, 17]] (inclusive)
268
+ Note a maximum of `self.max_windows` windows are used.
269
+ returns Tensor of shape (#windows, 2), each row is [center, width] normalized by video length
270
+ """
271
+ if len(windows) > self.max_windows:
272
+ random.shuffle(windows)
273
+ windows = windows[:self.max_windows]
274
+ if self.span_loss_type == "l1":
275
+ windows = torch.Tensor(windows) / (ctx_l * self.clip_len) # normalized windows in xx
276
+ windows = span_xx_to_cxw(windows) # normalized windows in cxw
277
+ elif self.span_loss_type == "ce":
278
+ windows = torch.Tensor([
279
+ [int(w[0] / self.clip_len), min(int(w[1] / self.clip_len), ctx_l) - 1]
280
+ for w in windows]).long() # inclusive
281
+ else:
282
+ raise NotImplementedError
283
+ return windows
284
+
285
+ def _get_query_feat_by_qid(self, qid):
286
+ # QVhighlight dataset
287
+ q_feat_path = join(self.q_feat_dir, f"qid{qid}.pt")
288
+ # q_feat = np.load(q_feat_path)[self.q_feat_type].astype(np.float32)
289
+ q_feat = torch.load(q_feat_path).numpy().astype(np.float32)
290
+ if self.q_feat_type == "last_hidden_state":
291
+ q_feat = q_feat[:self.max_q_l]
292
+ if self.normalize_t:
293
+ q_feat = l2_normalize_np_array(q_feat)
294
+ if self.txt_drop_ratio > 0:
295
+ q_feat = self.random_drop_rows(q_feat)
296
+ return torch.from_numpy(q_feat) # (D, ) or (Lq, D)
297
+
298
+ def random_drop_rows(self, embeddings):
299
+ """randomly mask num_drop rows in embeddings to be zero.
300
+ Args:
301
+ embeddings: np.ndarray (L, D)
302
+ """
303
+ num_drop_rows = round(len(embeddings) * self.txt_drop_ratio)
304
+ if num_drop_rows > 0:
305
+ row_indices = np.random.choice(
306
+ len(embeddings), size=num_drop_rows, replace=False)
307
+ embeddings[row_indices] = 0
308
+ return embeddings
309
+
310
+ def _get_video_feat_by_vid(self, vid):
311
+ v_feat_list = []
312
+ for _feat_dir in self.v_feat_dirs:
313
+ try:
314
+ _feat_path = join(_feat_dir, f"{vid}.pt")
315
+ _feat = torch.load(_feat_path)["features"][:self.max_v_l].numpy().astype(np.float32)
316
+ except Exception: # some feature files store a raw tensor rather than a dict with a "features" key
317
+ _feat_path = join(_feat_dir, f"{vid}.pt")
318
+ _feat = torch.load(_feat_path)[:self.max_v_l].numpy().astype(np.float32)
319
+ if self.normalize_v:
320
+ _feat = l2_normalize_np_array(_feat)
321
+ v_feat_list.append(_feat)
322
+ # some features are slightly longer than the others
323
+ min_len = min([len(e) for e in v_feat_list])
324
+ v_feat_list = [e[:min_len] for e in v_feat_list]
325
+ v_feat = np.concatenate(v_feat_list, axis=1) # (vlen=34, 768)
326
+ return torch.from_numpy(v_feat) # (Lv, D)
327
+
328
+
329
+
330
+ def start_end_collate(batch):
331
+ batch_meta = [e["meta"] for e in batch] # seems no need to collate ?
332
+
333
+ model_inputs_keys = batch[0]["model_inputs"].keys()
334
+ batched_data = dict()
335
+ for k in model_inputs_keys:
336
+ if k == "span_labels":
337
+ batched_data[k] = [dict(spans=e["model_inputs"]["span_labels"]) for e in batch]
338
+ continue
339
+ if k in ["saliency_pos_labels", "saliency_neg_labels"]:
340
+ batched_data[k] = torch.LongTensor([e["model_inputs"][k] for e in batch])
341
+ continue
342
+ if k == "saliency_all_labels":
343
+ pad_data, mask_data = pad_sequences_1d([e["model_inputs"][k] for e in batch], dtype=np.float32, fixed_length=None)
344
+ batched_data[k] = torch.tensor(pad_data, dtype=torch.float32)
345
+ continue
346
+ if k == 'qid':
347
+ batched_data[k] = [e["model_inputs"][k] for e in batch]
348
+ continue
349
+ if k == 'vid':
350
+ batched_data[k] = [e["model_inputs"][k] for e in batch]
351
+ continue
352
+ batched_data[k] = pad_sequences_1d(
353
+ [e["model_inputs"][k] for e in batch], dtype=torch.float32, fixed_length=None)
354
+ return batch_meta, batched_data
355
+
356
+
357
+ def prepare_batch_inputs(batched_model_inputs):
358
+ model_inputs = dict(
359
+ src_txt=batched_model_inputs["query_feat"][0],
360
+ src_txt_mask=batched_model_inputs["query_feat"][1],
361
+ src_vid=batched_model_inputs["video_feat"][0],
362
+ src_vid_mask=batched_model_inputs["video_feat"][1],
363
+ vid=batched_model_inputs["vid"],
364
+ qid=batched_model_inputs["qid"],
365
+ )
366
+ targets = {}
367
+
368
+ # import pdb; pdb.set_trace()
369
+
370
+ if "span_labels" in batched_model_inputs:
371
+ targets["span_labels"] = [
372
+ dict(spans=e["spans"])
373
+ for e in batched_model_inputs["span_labels"]
374
+ ]
375
+ if "saliency_pos_labels" in batched_model_inputs:
376
+ for name in ["saliency_pos_labels", "saliency_neg_labels"]:
377
+ targets[name] = batched_model_inputs[name]
378
+
379
+ if "saliency_all_labels" in batched_model_inputs:
380
+ targets["saliency_all_labels"] = batched_model_inputs["saliency_all_labels"]
381
+ targets["relevant_clips"] = batched_model_inputs["saliency_all_labels"]
382
+ targets = None if len(targets) == 0 else targets
383
+ return model_inputs, targets
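A minimal sketch of how the dataset, collate function, and input preparation above fit together (the paths, feature layout, and hyperparameters are placeholders, assuming QVHighlights-style jsonl annotations and precomputed CLIP features):

from torch.utils.data import DataLoader
from third_party.cgdetr.cg_detr.start_end_dataset import (
    StartEndDataset, start_end_collate, prepare_batch_inputs)

# placeholder paths; point them at your extracted video/text features and jsonl annotations
dataset = StartEndDataset(
    dset_name="hl",
    data_path="data/qvhighlights/highlight_train_release.jsonl",
    v_feat_dirs=["features/qvhighlights/clip_features"],
    q_feat_dir="features/qvhighlights/clip_text_features",
    q_feat_type="last_hidden_state",
    max_q_l=32, max_v_l=75, ctx_mode="video_tef", clip_len=2)

loader = DataLoader(dataset, batch_size=32, collate_fn=start_end_collate, shuffle=True)
batch_meta, batched_inputs = next(iter(loader))
model_inputs, targets = prepare_batch_inputs(batched_inputs)
# model_inputs holds padded src_txt/src_vid tensors with masks; targets holds span and saliency labels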
third_party/cgdetr/cg_detr/text_encoder.py ADDED
@@ -0,0 +1,53 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from easydict import EasyDict as edict
5
+ from xml.model_components import BertAttention, TrainablePositionalEncoding
6
+
7
+
8
+ class TextEncoder(nn.Module):
9
+ def __init__(self, hidden_size, drop, input_drop, nheads, max_position_embeddings):
10
+ super().__init__()
11
+ self.transformer_encoder = BertAttention(edict(
12
+ hidden_size=hidden_size,
13
+ intermediate_size=hidden_size,
14
+ hidden_dropout_prob=drop,
15
+ attention_probs_dropout_prob=drop,
16
+ num_attention_heads=nheads,
17
+ ))
18
+ self.pos_embed = TrainablePositionalEncoding(
19
+ max_position_embeddings=max_position_embeddings,
20
+ hidden_size=hidden_size,
21
+ dropout=input_drop,
22
+ )
23
+ self.modular_vector_mapping = nn.Linear(hidden_size, 1, bias=False)
24
+
25
+ def forward(self, feat, mask):
26
+ """
27
+ Args:
28
+ feat: (N, L, D=hidden_size)
29
+ mask: (N, L) with 1 indicates valid
30
+
31
+ Returns:
32
+ (N, D)
33
+ """
34
+ feat = self.pos_embed(feat) # (N, L, D)
35
+ feat = self.transformer_encoder(feat, mask.unsqueeze(1))
36
+ att_scores = self.modular_vector_mapping(feat) # (N, L, 1)
37
+ att_scores = F.softmax(mask_logits(att_scores, mask.unsqueeze(2)), dim=1)
38
+ pooled_feat = torch.einsum("blm,bld->bmd", att_scores, feat) # (N, 2 or 1, D)
39
+ return pooled_feat.squeeze(1)
40
+
41
+
42
+ def mask_logits(target, mask):
43
+ return target * mask + (1 - mask) * (-1e10)
44
+
45
+
46
+ def build_text_encoder(args):
47
+ return TextEncoder(
48
+ hidden_size=args.hidden_dim,
49
+ drop=args.dropout,
50
+ input_drop=args.input_dropout,
51
+ nheads=args.nheads,
52
+ max_position_embeddings=args.max_q_l
53
+ )
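A small sketch of building and calling this attention-pooling text encoder (hyperparameter values are illustrative, and the xml.model_components dependency imported above must be on the path):

from types import SimpleNamespace
import torch
from third_party.cgdetr.cg_detr.text_encoder import build_text_encoder

# illustrative hyperparameters matching the attribute names read by build_text_encoder
args = SimpleNamespace(hidden_dim=256, dropout=0.1, input_dropout=0.5, nheads=8, max_q_l=32)
encoder = build_text_encoder(args)

feat = torch.randn(4, 32, 256)    # (N, L, D) query token features already projected to hidden_dim
mask = torch.ones(4, 32)          # 1 marks valid tokens
pooled = encoder(feat, mask)      # (N, D) attention-pooled sentence representation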
third_party/cgdetr/cg_detr/train.py ADDED
@@ -0,0 +1,283 @@
1
+ import os
2
+ import time
3
+ import json
4
+ import pprint
5
+ import random
6
+ import numpy as np
7
+ from tqdm import tqdm, trange
8
+ from collections import defaultdict
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.backends.cudnn as cudnn
13
+ from torch.utils.data import DataLoader
14
+ from torch.utils.tensorboard import SummaryWriter
15
+
16
+ # import sys
17
+ # print(sys.path)
18
+ # sys.path.insert(os.getcwd(),0)
19
+ # print(sys.path)
20
+
21
+ from cg_detr.config import BaseOptions
22
+ from cg_detr.start_end_dataset import StartEndDataset, start_end_collate, prepare_batch_inputs
23
+ from cg_detr.inference import eval_epoch, start_inference, setup_model
24
+ from utils.basic_utils import AverageMeter, dict_to_markdown
25
+ from utils.model_utils import count_parameters
26
+
27
+
28
+ import logging
29
+ logger = logging.getLogger(__name__)
30
+ logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
31
+ datefmt="%Y-%m-%d %H:%M:%S",
32
+ level=logging.INFO)
33
+
34
+
35
+ def set_seed(seed, use_cuda=True):
36
+ random.seed(seed)
37
+ np.random.seed(seed)
38
+ torch.manual_seed(seed)
39
+ if use_cuda:
40
+ torch.cuda.manual_seed_all(seed)
41
+
42
+
43
+ def train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer):
44
+ logger.info(f'[Epoch {epoch_i+1}]')
45
+ model.train()
46
+ criterion.train()
47
+
48
+ # init meters
49
+ time_meters = defaultdict(AverageMeter)
50
+ loss_meters = defaultdict(AverageMeter)
51
+
52
+ num_training_examples = len(train_loader)
53
+ timer_dataloading = time.time()
54
+ for batch_idx, batch in tqdm(enumerate(train_loader),
55
+ desc="Training Iteration",
56
+ total=num_training_examples):
57
+ time_meters["dataloading_time"].update(time.time() - timer_dataloading)
58
+ timer_start = time.time()
59
+ model_inputs, targets = prepare_batch_inputs(batch[1], opt.device, non_blocking=opt.pin_memory)
60
+ time_meters["prepare_inputs_time"].update(time.time() - timer_start)
61
+ timer_start = time.time()
62
+
63
+ outputs = model(**model_inputs, targets=targets)
64
+ loss_dict = criterion(outputs, targets)
65
+ weight_dict = criterion.weight_dict
66
+ losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
67
+ time_meters["model_forward_time"].update(time.time() - timer_start)
68
+
69
+ timer_start = time.time()
70
+ optimizer.zero_grad()
71
+ losses.backward()
72
+ if opt.grad_clip > 0:
73
+ nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
74
+ optimizer.step()
75
+ time_meters["model_backward_time"].update(time.time() - timer_start)
76
+
77
+ loss_dict["loss_overall"] = float(losses) # for logging only
78
+ for k, v in loss_dict.items():
79
+ loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))
80
+
81
+ timer_dataloading = time.time()
82
+ if opt.debug and batch_idx == 3:
83
+ break
84
+
85
+ # print/add logs
86
+ tb_writer.add_scalar("Train/lr", float(optimizer.param_groups[0]["lr"]), epoch_i+1)
87
+ for k, v in loss_meters.items():
88
+ tb_writer.add_scalar("Train/{}".format(k), v.avg, epoch_i+1)
89
+
90
+ to_write = opt.train_log_txt_formatter.format(
91
+ time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
92
+ epoch=epoch_i+1,
93
+ loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in loss_meters.items()]))
94
+ with open(opt.train_log_filepath, "a") as f:
95
+ f.write(to_write)
96
+
97
+ logger.info("Epoch time stats:")
98
+ for name, meter in time_meters.items():
99
+ d = {k: f"{getattr(meter, k):.4f}" for k in ["max", "min", "avg"]}
100
+ logger.info(f"{name} ==> {d}")
101
+
102
+
103
+ def train(model, criterion, optimizer, lr_scheduler, train_dataset, val_dataset, opt):
104
+ if opt.device.type == "cuda":
105
+ logger.info("CUDA enabled.")
106
+ model.to(opt.device)
107
+
108
+ tb_writer = SummaryWriter(opt.tensorboard_log_dir)
109
+ tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
110
+ opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
111
+ opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"
112
+
113
+
114
+ train_loader = DataLoader(
115
+ train_dataset,
116
+ collate_fn=start_end_collate,
117
+ batch_size=opt.bsz,
118
+ num_workers=opt.num_workers,
119
+ shuffle=True,
120
+ pin_memory=opt.pin_memory
121
+ )
122
+
123
+ prev_best_score = 0.
124
+ es_cnt = 0
125
+ # start_epoch = 0
126
+ if opt.start_epoch is None:
127
+ start_epoch = -1 if opt.eval_untrained else 0
128
+ else:
129
+ start_epoch = opt.start_epoch
130
+ save_submission_filename = "latest_{}_{}_preds.jsonl".format(opt.dset_name, opt.eval_split_name)
131
+ for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
132
+ if epoch_i > -1:
133
+ train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer)
134
+ lr_scheduler.step()
135
+ eval_epoch_interval = opt.eval_epoch
136
+ if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0:
137
+ with torch.no_grad():
138
+ metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \
139
+ eval_epoch(model, val_dataset, opt, save_submission_filename, epoch_i, criterion, tb_writer)
140
+
141
+ # log
142
+ to_write = opt.eval_log_txt_formatter.format(
143
+ time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
144
+ epoch=epoch_i,
145
+ loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in eval_loss_meters.items()]),
146
+ eval_metrics_str=json.dumps(metrics_no_nms))
147
+
148
+ with open(opt.eval_log_filepath, "a") as f:
149
+ f.write(to_write)
150
+ logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4)))
151
+ if metrics_nms is not None:
152
+ logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4)))
153
+
154
+ metrics = metrics_no_nms
155
+ for k, v in metrics["brief"].items():
156
+ tb_writer.add_scalar(f"Eval/{k}", float(v), epoch_i+1)
157
+
158
+ if opt.dset_name in ['hl']:
159
+ stop_score = metrics["brief"]["MR-full-mAP"]
160
+ else:
161
+ stop_score = (metrics["brief"]["[email protected]"] + metrics["brief"]["[email protected]"]) / 2
162
+
163
+
164
+ if stop_score > prev_best_score:
165
+ es_cnt = 0
166
+ prev_best_score = stop_score
167
+
168
+ checkpoint = {
169
+ "model": model.state_dict(),
170
+ "optimizer": optimizer.state_dict(),
171
+ "lr_scheduler": lr_scheduler.state_dict(),
172
+ "epoch": epoch_i,
173
+ "opt": opt
174
+ }
175
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"))
176
+
177
+ best_file_paths = [e.replace("latest", "best") for e in latest_file_paths]
178
+ for src, tgt in zip(latest_file_paths, best_file_paths):
179
+ os.renames(src, tgt)
180
+ logger.info("The checkpoint file has been updated.")
181
+ else:
182
+ es_cnt += 1
183
+ if opt.max_es_cnt != -1 and es_cnt > opt.max_es_cnt: # early stop
184
+ with open(opt.train_log_filepath, "a") as f:
185
+ f.write(f"Early Stop at epoch {epoch_i}")
186
+ logger.info(f"\n>>>>> Early stop at epoch {epoch_i} {prev_best_score}\n")
187
+ break
188
+
189
+ # save ckpt
190
+ checkpoint = {
191
+ "model": model.state_dict(),
192
+ "optimizer": optimizer.state_dict(),
193
+ "lr_scheduler": lr_scheduler.state_dict(),
194
+ "epoch": epoch_i,
195
+ "opt": opt
196
+ }
197
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_latest.ckpt"))
198
+
199
+ # save_interval = 10 if "subs_train" in opt.train_path else 50 # smaller for pretrain
200
+ # if (epoch_i + 1) % save_interval == 0 or (epoch_i + 1) % opt.lr_drop == 0: # additional copies
201
+ # checkpoint = {
202
+ # "model": model.state_dict(),
203
+ # "optimizer": optimizer.state_dict(),
204
+ # "epoch": epoch_i,
205
+ # "opt": opt
206
+ # }
207
+ # torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_e{epoch_i:04d}.ckpt"))
208
+
209
+ if opt.debug:
210
+ break
211
+
212
+ tb_writer.close()
213
+
214
+
215
+
216
+ def start_training():
217
+ logger.info("Setup config, data and model...")
218
+ opt = BaseOptions().parse()
219
+ set_seed(opt.seed)
220
+ if opt.debug: # keep the model run deterministically
221
+ # 'cudnn.benchmark = True' enables auto-tuning to find the best algorithm for a specific input/net config.
222
+ # Enable this only when input size is fixed.
223
+ cudnn.benchmark = False
224
+ cudnn.deterministic = True
225
+
226
+
227
+ dataset_config = dict(
228
+ dset_name=opt.dset_name,
229
+ data_path=opt.train_path,
230
+ v_feat_dirs=opt.v_feat_dirs,
231
+ q_feat_dir=opt.t_feat_dir,
232
+ q_feat_type="last_hidden_state",
233
+ max_q_l=opt.max_q_l,
234
+ max_v_l=opt.max_v_l,
235
+ ctx_mode=opt.ctx_mode,
236
+ data_ratio=opt.data_ratio,
237
+ normalize_v=not opt.no_norm_vfeat,
238
+ normalize_t=not opt.no_norm_tfeat,
239
+ clip_len=opt.clip_length,
240
+ max_windows=opt.max_windows,
241
+ span_loss_type=opt.span_loss_type,
242
+ txt_drop_ratio=opt.txt_drop_ratio,
243
+ dset_domain=opt.dset_domain,
244
+ )
245
+ dataset_config["data_path"] = opt.train_path
246
+ train_dataset = StartEndDataset(**dataset_config)
247
+ # import pdb; pdb.set_trace()
248
+ # train_dataset[0]
249
+
250
+ if opt.eval_path is not None:
251
+ dataset_config["data_path"] = opt.eval_path
252
+ dataset_config["txt_drop_ratio"] = 0
253
+ dataset_config["q_feat_dir"] = opt.t_feat_dir.replace("sub_features", "text_features") # for pretraining
254
+ # dataset_config["load_labels"] = False # uncomment to calculate eval loss
255
+
256
+ eval_dataset = StartEndDataset(**dataset_config)
257
+
258
+ else:
259
+ eval_dataset = None
260
+
261
+ model, criterion, optimizer, lr_scheduler = setup_model(opt)
262
+ logger.info(f"Model {model}")
263
+ count_parameters(model)
264
+ logger.info("Start Training...")
265
+
266
+ train(model, criterion, optimizer, lr_scheduler, train_dataset, eval_dataset, opt)
267
+
268
+ return opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"), opt.eval_split_name, opt.eval_path, opt.debug, opt
269
+
270
+
271
+ if __name__ == '__main__':
272
+ best_ckpt_path, eval_split_name, eval_path, debug, opt = start_training()
273
+ if not debug:
274
+ input_args = ["--resume", best_ckpt_path,
275
+ "--eval_split_name", eval_split_name,
276
+ "--eval_path", eval_path]
277
+
278
+ import sys
279
+ sys.argv[1:] = input_args
280
+ logger.info("\n\n\nFINISHED TRAINING!!!")
281
+ logger.info("Evaluating model at {}".format(best_ckpt_path))
282
+ logger.info("Input args {}".format(sys.argv[1:]))
283
+ start_inference(opt)
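For reference, a rough sketch of driving this entry point programmatically; the flag names mirror the train.sh script earlier in this diff, the paths are placeholders, and it assumes the interpreter is started from third_party/cgdetr so that the cg_detr and utils packages are importable:

import sys
from cg_detr.train import start_training

# placeholder arguments; see cg_detr/config.py (BaseOptions) for the full option list
sys.argv[1:] = [
    "--dset_name", "hl",
    "--train_path", "data/qvhighlights/highlight_train_release.jsonl",
    "--eval_path", "data/qvhighlights/highlight_val_release.jsonl",
    "--eval_split_name", "val",
    "--v_feat_dirs", "features/qvhighlights/clip_features",
    "--t_feat_dir", "features/qvhighlights/clip_text_features",
    "--v_feat_dim", "512", "--t_feat_dim", "512",
    "--results_root", "results", "--exp_id", "demo",
]
best_ckpt_path, eval_split_name, eval_path, debug, opt = start_training()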
third_party/cgdetr/cg_detr/transformer.py ADDED
@@ -0,0 +1,871 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ DETR Transformer class.
4
+
5
+ Copy-paste from torch.nn.Transformer with modifications:
6
+ * positional encodings are passed in MHattention
7
+ * extra LN at the end of encoder is removed
8
+ * decoder returns a stack of activations from all decoding layers
9
+ """
10
+ import copy
11
+ from typing import Optional
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torch import nn, Tensor
15
+ import math
16
+ import numpy as np
17
+ from .attention import MultiheadAttention
18
+ from .crossattention import MultiheadAttention as cateattention
19
+
20
+ class MLP(nn.Module):
21
+ """ Very simple multi-layer perceptron (also called FFN)"""
22
+
23
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
24
+ super().__init__()
25
+ self.num_layers = num_layers
26
+ h = [hidden_dim] * (num_layers - 1)
27
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
28
+
29
+ def forward(self, x):
30
+ for i, layer in enumerate(self.layers):
31
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
32
+ return x
33
+
34
+ def inverse_sigmoid(x, eps=1e-3):
35
+ x = x.clamp(min=0, max=1)
36
+ x1 = x.clamp(min=eps)
37
+ x2 = (1 - x).clamp(min=eps)
38
+ return torch.log(x1/x2)
39
+
40
+ def gen_sineembed_for_position(pos_tensor, d_model):
41
+ # n_query, bs, _ = pos_tensor.size()
42
+ # sineembed_tensor = torch.zeros(n_query, bs, 256)
43
+ scale = 2 * math.pi
44
+ dim_t = torch.arange(d_model//2, dtype=torch.float32, device=pos_tensor.device)
45
+ dim_t = 10000 ** (2 * (dim_t // 2) / (d_model//2))
46
+ center_embed = pos_tensor[:, :, 0] * scale
47
+ pos_x = center_embed[:, :, None] / dim_t
48
+ pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
49
+
50
+ span_embed = pos_tensor[:, :, 1] * scale
51
+ pos_w = span_embed[:, :, None] / dim_t
52
+ pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
53
+
54
+ pos = torch.cat((pos_x, pos_w), dim=2)
55
+ return pos
56
+
57
+ class Transformer(nn.Module):
58
+
59
+ def __init__(self, d_model=512, nhead=8, num_queries=2, num_encoder_layers=6,
60
+ num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
61
+ activation="relu", normalize_before=False,
62
+ return_intermediate_dec=False, query_dim=2,
63
+ keep_query_pos=False, query_scale_type='cond_elewise',
64
+ num_patterns=0,
65
+ modulate_t_attn=True,
66
+ bbox_embed_diff_each_layer=False, args=None
67
+ ):
68
+ super().__init__()
69
+ self.args = args
70
+ mcls_encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
71
+ dropout, activation, normalize_before)
72
+ mcls_encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
73
+ self.mcls_encoder = TransformerEncoder(mcls_encoder_layer, args.moment_layers, mcls_encoder_norm)
74
+
75
+ t2v_encoder_layer = T2V_TransformerEncoderLayer(d_model, nhead, dim_feedforward,
76
+ dropout, activation, normalize_before, self.args.num_dummies)
77
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
78
+ self.t2v_encoder = TransformerCATEEncoder(t2v_encoder_layer, args.t2v_layers, encoder_norm)
79
+
80
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
81
+ dropout, activation, normalize_before)
82
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
83
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
84
+
85
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
86
+ dropout, activation, normalize_before, keep_query_pos=keep_query_pos)
87
+ decoder_norm = nn.LayerNorm(d_model)
88
+ self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
89
+ return_intermediate=return_intermediate_dec,
90
+ d_model=d_model, query_dim=query_dim, keep_query_pos=keep_query_pos, query_scale_type=query_scale_type,
91
+ modulate_t_attn=modulate_t_attn,
92
+ bbox_embed_diff_each_layer=bbox_embed_diff_each_layer)
93
+
94
+ self._reset_parameters()
95
+
96
+ self.d_model = d_model
97
+ self.nhead = nhead
98
+ self.dec_layers = num_decoder_layers
99
+ self.num_queries = num_queries
100
+ self.num_patterns = num_patterns
101
+
102
+ def _reset_parameters(self):
103
+ for p in self.parameters():
104
+ if p.dim() > 1:
105
+ nn.init.xavier_uniform_(p)
106
+
107
+ def forward(self, src, mask, query_embed, pos_embed, video_length=None, moment_idx=None, msrc=None, mpos=None, mmask=None,
108
+ nmsrc=None, nmpos=None, nmmask=None,
109
+ ctxtoken=None, gtoken=None, gpos=None, vlen=None):
110
+ """
111
+ Args:
112
+ src: (batch_size, L, d)
113
+ mask: (batch_size, L)
114
+ query_embed: (#queries, d)
115
+ pos_embed: (batch_size, L, d) the same as src
116
+ video length: feature shape
117
+ vlen: actual video length
118
+ Returns:
119
+ """
120
+ # moment token
121
+ device = ctxtoken.device
122
+ if msrc is not None:
123
+ msrc = msrc.permute(1, 0, 2) # (L, batch_size, d)
124
+ mpos = mpos.permute(1, 0, 2) # (L, batch_size, d)
125
+ mmemory = self.mcls_encoder(msrc, src_key_padding_mask=mmask, pos=mpos) # (L, batch_size, d)
126
+ mmemory_moment, mmemory_frames = mmemory[0], mmemory[1:]
127
+ else:
128
+ mmemory_moment = None
129
+ mmemory_frames = None
130
+ if nmsrc is not None:
131
+ nmsrc = nmsrc.permute(1, 0, 2) # (L, batch_size, d)
132
+ nmpos = nmpos.permute(1, 0, 2) # (L, batch_size, d)
133
+ nmmemory = self.mcls_encoder(nmsrc, src_key_padding_mask=nmmask, pos=nmpos) # (L, batch_size, d)
134
+ nmmemory_moment, nmmemory_frames = nmmemory[0], nmmemory[1:]
135
+ else:
136
+ nmmemory_moment = None
137
+ nmmemory_frames = None
138
+
139
+ # flatten NxCxHxW to HWxNxC
140
+ bs, l, d = src.shape
141
+ src = src.permute(1, 0, 2) # (L, batch_size, d)
142
+ pos_embed = pos_embed.permute(1, 0, 2) # (L, batch_size, d)
143
+ refpoint_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # (#queries, batch_size, d)
144
+
145
+ # import pdb; pdb.set_trace()
146
+ # print(src.dtype)
147
+ t2v_src, attn_weights = self.t2v_encoder(src, src_key_padding_mask=mask, pos=pos_embed, video_length=video_length) # (L, batch_size, d)
148
+
149
+ # Saliency Token
150
+ ## Context
151
+ ctx_src_ = ctxtoken.permute(1, 0, 2) # L b d
152
+
153
+ ## Distribution Token with 10 prompt tokens
154
+ ### Video clip feature - context (avg) --> Find top 10 similar tokens --> weighted sum
155
+ # import pdb; pdb.set_trace()
156
+ fr_token_sim = torch.softmax(torch.matmul(F.normalize((src[:video_length] - ctx_src_).permute(1, 0, 2), dim=2), F.normalize(gtoken, dim=1).T), dim=-1)# src : b 75 d, token : 10 x d --> b 75 10
157
+ ### Calculate clip importance
158
+ frame_importance = attn_weights[:, :, self.args.num_dummies:].sum(2).clone().detach() # b 75
159
+ ### Masking empty clips
160
+ for i in range(len(frame_importance)):
161
+ frame_importance[i][vlen[i]:] *= 0.
162
+ ### Normalize
163
+ frame_importance = (frame_importance / frame_importance.sum(1).unsqueeze(1)) * frame_importance.size(1) # b 75
164
+ ### Scale the similarity with importance
165
+ fr_token_sim = fr_token_sim * frame_importance.unsqueeze(2).repeat(1, 1, fr_token_sim.size(2)) # b 75 10
166
+ fr_token_sim = fr_token_sim.mean(1) # b 10
167
+ topk_val, topkidx = torch.topk(fr_token_sim, k=self.args.num_prompts, dim=1)
168
+ src_ = torch.zeros((len(fr_token_sim), self.d_model), dtype=torch.bfloat16).to(device)
169
+ for i in range(len(fr_token_sim)):
170
+ src_[i] = (topk_val[i].unsqueeze(1) * gtoken[topkidx[i]]).sum(0)
171
+ src_ = src_.reshape(1, src.size(1), -1)
172
+
173
+ ## Add context and distribution token
174
+ src_ = src_ + ctx_src_
175
+ pos_ = gpos.reshape([1, 1, self.d_model]).repeat(1, pos_embed.shape[1], 1)
176
+ mask_ = torch.tensor([[False]]).to(mask.device).repeat(mask.shape[0], 1)
177
+
178
+ # import pdb; pdb.set_trace()
179
+ src_, _ = self.t2v_encoder(src_, src_key_padding_mask=mask_, pos=pos_,
180
+ video_length=video_length, dummy=False) # (L, batch_size, d)
181
+
182
+ src = torch.cat([src_, t2v_src], dim=0)
183
+ mask = torch.cat([mask_, mask], dim=1)
184
+ pos_embed = torch.cat([pos_, pos_embed], dim=0)
185
+
186
+ src = src[:video_length + 1]
187
+ mask = mask[:, :video_length + 1]
188
+ pos_embed = pos_embed[:video_length + 1]
189
+
190
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) # (L, batch_size, d)
191
+ memory_global, memory_local = memory[0], memory[1:]
192
+ memory_local += memory_global.unsqueeze(0).repeat(memory_local.size(0), 1, 1)
193
+ mask_local = mask[:, 1:]
194
+ pos_embed_local = pos_embed[1:]
195
+
196
+ tgt = torch.zeros(refpoint_embed.shape[0], bs, d).to(device)
197
+ tgt = tgt.type(torch.bfloat16)
198
+
199
+ # import pdb; pdb.set_trace()
200
+ hs, references = self.decoder(tgt, memory_local, memory_key_padding_mask=mask_local, pos=pos_embed_local, refpoints_unsigmoid=refpoint_embed) # (#layers, #queries, batch_size, d)
201
+ memory_local = memory_local.transpose(0, 1) # (batch_size, L, d)
202
+
203
+ return hs, references, memory_local, memory_global, attn_weights, mmemory_moment, nmmemory_moment, mmemory_frames, nmmemory_frames
204
+
205
+
206
+ class TransformerCATEEncoder(nn.Module):
207
+ def __init__(self, encoder_layer, num_layers, norm=None, return_intermediate=False):
208
+ super().__init__()
209
+ self.layers = _get_clones(encoder_layer, num_layers)
210
+ self.num_layers = num_layers
211
+ self.norm = norm
212
+ self.return_intermediate = return_intermediate
213
+
214
+ def forward(self, src,
215
+ mask: Optional[Tensor] = None,
216
+ src_key_padding_mask: Optional[Tensor] = None,
217
+ pos: Optional[Tensor] = None,
218
+ dummy=True,
219
+ **kwargs):
220
+ output = src
221
+
222
+ intermediate = []
223
+ attn_weights = None
224
+ for i, layer in enumerate(self.layers):
225
+ output, attn_weight = layer(output, src_mask=mask,
226
+ src_key_padding_mask=src_key_padding_mask, pos=pos, dummy=dummy, **kwargs)
227
+ if attn_weights is None:
228
+ attn_weights = attn_weight
229
+ else:
230
+ attn_weights = attn_weights + attn_weight
231
+ if self.return_intermediate:
232
+ intermediate.append(output)
233
+ attn_weights /= self.num_layers
234
+
235
+ if self.norm is not None:
236
+ output = self.norm(output)
237
+
238
+ if self.return_intermediate:
239
+ return torch.stack(intermediate)
240
+
241
+ return output, attn_weights
242
+
243
+ class TransformerEncoder(nn.Module):
244
+
245
+ def __init__(self, encoder_layer, num_layers, norm=None, return_intermediate=False):
246
+ super().__init__()
247
+ self.layers = _get_clones(encoder_layer, num_layers)
248
+ self.num_layers = num_layers
249
+ self.norm = norm
250
+ self.return_intermediate = return_intermediate
251
+
252
+ def forward(self, src,
253
+ mask: Optional[Tensor] = None,
254
+ src_key_padding_mask: Optional[Tensor] = None,
255
+ pos: Optional[Tensor] = None,
256
+ **kwargs):
257
+ output = src
258
+
259
+ intermediate = []
260
+
261
+ for layer in self.layers:
262
+ output = layer(output, src_mask=mask,
263
+ src_key_padding_mask=src_key_padding_mask, pos=pos, **kwargs)
264
+ if self.return_intermediate:
265
+ intermediate.append(output)
266
+
267
+ if self.norm is not None:
268
+ output = self.norm(output)
269
+
270
+ if self.return_intermediate:
271
+ return torch.stack(intermediate)
272
+
273
+ return output
274
+
275
+
276
+ class TransformerDecoder(nn.Module):
277
+
278
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False,
279
+ d_model=256, query_dim=2, keep_query_pos=False, query_scale_type='cond_elewise',
280
+ modulate_t_attn=False,
281
+ bbox_embed_diff_each_layer=False,
282
+ ):
283
+ super().__init__()
284
+ self.layers = _get_clones(decoder_layer, num_layers)
285
+ self.num_layers = num_layers
286
+ self.norm = norm
287
+ self.return_intermediate = return_intermediate
288
+ assert return_intermediate
289
+ self.query_dim = query_dim
290
+
291
+ assert query_scale_type in ['cond_elewise', 'cond_scalar', 'fix_elewise']
292
+ self.query_scale_type = query_scale_type
293
+ if query_scale_type == 'cond_elewise':
294
+ self.query_scale = MLP(d_model, d_model, d_model, 2)
295
+ elif query_scale_type == 'cond_scalar':
296
+ self.query_scale = MLP(d_model, d_model, 1, 2)
297
+ elif query_scale_type == 'fix_elewise':
298
+ self.query_scale = nn.Embedding(num_layers, d_model)
299
+ else:
300
+ raise NotImplementedError("Unknown query_scale_type: {}".format(query_scale_type))
301
+
302
+ self.ref_point_head = MLP(d_model, d_model, d_model, 2)
303
+
304
+ # self.bbox_embed = None
305
+ # for DAB-detr
306
+ if bbox_embed_diff_each_layer:
307
+ self.bbox_embed = nn.ModuleList([MLP(d_model, d_model, 2, 3) for i in range(num_layers)])
308
+ else:
309
+ self.bbox_embed = MLP(d_model, d_model, 2, 3)
310
+ # init bbox_embed
311
+ if bbox_embed_diff_each_layer:
312
+ for bbox_embed in self.bbox_embed:
313
+ nn.init.constant_(bbox_embed.layers[-1].weight.data, 0)
314
+ nn.init.constant_(bbox_embed.layers[-1].bias.data, 0)
315
+ else:
316
+ nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
317
+ nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
318
+ self.d_model = d_model
319
+ self.modulate_t_attn = modulate_t_attn
320
+ self.bbox_embed_diff_each_layer = bbox_embed_diff_each_layer
321
+
322
+ if modulate_t_attn:
323
+ self.ref_anchor_head = MLP(d_model, d_model, 1, 2)
324
+
325
+ if not keep_query_pos:
326
+ for layer_id in range(num_layers - 1):
327
+ self.layers[layer_id + 1].ca_qpos_proj = None
328
+
329
+ def forward(self, tgt, memory,
330
+ tgt_mask: Optional[Tensor] = None,
331
+ memory_mask: Optional[Tensor] = None,
332
+ tgt_key_padding_mask: Optional[Tensor] = None,
333
+ memory_key_padding_mask: Optional[Tensor] = None,
334
+ pos: Optional[Tensor] = None,
335
+ refpoints_unsigmoid: Optional[Tensor] = None, # num_queries, bs, 2
336
+ ):
337
+ output = tgt
338
+
339
+ intermediate = []
340
+ reference_points = refpoints_unsigmoid.sigmoid()
341
+ ref_points = [reference_points]
342
+
343
+ # import pdb; pdb.set_trace()
344
+
345
+ for layer_id, layer in enumerate(self.layers):
346
+ obj_center = reference_points[..., :self.query_dim]
347
+ # get sine embedding for the query vector
348
+ query_sine_embed = gen_sineembed_for_position(obj_center, self.d_model)
349
+ query_sine_embed = query_sine_embed.type(torch.bfloat16)
350
+
351
+ query_pos = self.ref_point_head(query_sine_embed)
352
+ # For the first decoder layer, we do not apply transformation over p_s
353
+ if self.query_scale_type != 'fix_elewise':
354
+ if layer_id == 0:
355
+ pos_transformation = 1
356
+ else:
357
+ pos_transformation = self.query_scale(output)
358
+ else:
359
+ pos_transformation = self.query_scale.weight[layer_id]
360
+
361
+ # apply transformation
362
+ query_sine_embed = query_sine_embed * pos_transformation
363
+
364
+ # modulated HW attentions
365
+ if self.modulate_t_attn:
366
+ reft_cond = self.ref_anchor_head(output).sigmoid() # nq, bs, 1
367
+
368
+ query_sine_embed *= (reft_cond[..., 0] / obj_center[..., 1]).unsqueeze(-1)
369
+
370
+
371
+ output = layer(output, memory, tgt_mask=tgt_mask,
372
+ memory_mask=memory_mask,
373
+ tgt_key_padding_mask=tgt_key_padding_mask,
374
+ memory_key_padding_mask=memory_key_padding_mask,
375
+ pos=pos, query_pos=query_pos, query_sine_embed=query_sine_embed,
376
+ is_first=(layer_id == 0))
377
+
378
+ # iter update
379
+ if self.bbox_embed is not None:
380
+ if self.bbox_embed_diff_each_layer:
381
+ tmp = self.bbox_embed[layer_id](output)
382
+ else:
383
+ tmp = self.bbox_embed(output)
384
+ # import ipdb; ipdb.set_trace()
385
+ tmp[..., :self.query_dim] += inverse_sigmoid(reference_points)
386
+ new_reference_points = tmp[..., :self.query_dim].sigmoid()
387
+ if layer_id != self.num_layers - 1:
388
+ ref_points.append(new_reference_points)
389
+ reference_points = new_reference_points.detach()
390
+
391
+ if self.return_intermediate:
392
+ intermediate.append(self.norm(output))
393
+
394
+ if self.norm is not None:
395
+ output = self.norm(output)
396
+ if self.return_intermediate:
397
+ intermediate.pop()
398
+ intermediate.append(output)
399
+
400
+ if self.return_intermediate:
401
+ if self.bbox_embed is not None:
402
+ return [
403
+ torch.stack(intermediate).transpose(1, 2),
404
+ torch.stack(ref_points).transpose(1, 2),
405
+ ]
406
+ else:
407
+ return [
408
+ torch.stack(intermediate).transpose(1, 2),
409
+ reference_points.unsqueeze(0).transpose(1, 2)
410
+ ]
411
+
412
+ return output.unsqueeze(0)
413
+
414
+
415
+ class TransformerEncoderLayerThin(nn.Module):
416
+
417
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
418
+ activation="relu", normalize_before=False):
419
+ super().__init__()
420
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
421
+ # Implementation of Feedforward model
422
+ # self.linear1 = nn.Linear(d_model, dim_feedforward)
423
+ # self.dropout = nn.Dropout(dropout)
424
+ # self.linear2 = nn.Linear(dim_feedforward, d_model)
425
+ self.linear = nn.Linear(d_model, d_model)
426
+ self.norm = nn.LayerNorm(d_model)
427
+ self.dropout = nn.Dropout(dropout)
428
+
429
+ # self.activation = _get_activation_fn(activation)
430
+ self.normalize_before = normalize_before
431
+
432
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
433
+ return tensor if pos is None else tensor + pos
434
+
435
+ def forward_post(self,
436
+ src,
437
+ src_mask: Optional[Tensor] = None,
438
+ src_key_padding_mask: Optional[Tensor] = None,
439
+ pos: Optional[Tensor] = None):
440
+ q = k = self.with_pos_embed(src, pos)
441
+ src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
442
+ key_padding_mask=src_key_padding_mask)[0]
443
+ src2 = self.linear(src2)
444
+ src = src + self.dropout(src2)
445
+ src = self.norm(src)
446
+ # src = src + self.dropout1(src2)
447
+ # src = self.norm1(src)
448
+ # src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
449
+ # src = src + self.dropout2(src2)
450
+ # src = self.norm2(src)
451
+ return src
452
+
453
+ def forward_pre(self, src,
454
+ src_mask: Optional[Tensor] = None,
455
+ src_key_padding_mask: Optional[Tensor] = None,
456
+ pos: Optional[Tensor] = None):
457
+ """not used"""
458
+ src2 = self.norm1(src)
459
+ q = k = self.with_pos_embed(src2, pos)
460
+ src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
461
+ key_padding_mask=src_key_padding_mask)[0]
462
+ src = src + self.dropout1(src2)
463
+ src2 = self.norm2(src)
464
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
465
+ src = src + self.dropout2(src2)
466
+ return src
467
+
468
+ def forward(self, src,
469
+ src_mask: Optional[Tensor] = None,
470
+ src_key_padding_mask: Optional[Tensor] = None,
471
+ pos: Optional[Tensor] = None):
472
+ if self.normalize_before:
473
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
474
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
475
+
476
+
477
+ class T2V_TransformerEncoderLayer(nn.Module):
478
+
479
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
480
+ activation="relu", normalize_before=False, num_dummies=3):
481
+ super().__init__()
482
+ self.self_attn = cateattention(d_model, nhead, dropout=dropout, num_dummies=num_dummies)
483
+ # Implementation of Feedforward model
484
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
485
+ self.dropout = nn.Dropout(dropout)
486
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
487
+
488
+ self.norm1 = nn.LayerNorm(d_model)
489
+ self.norm2 = nn.LayerNorm(d_model)
490
+ self.dropout1 = DropPath(dropout)
491
+ self.dropout2 = DropPath(dropout)
492
+
493
+ self.activation = _get_activation_fn(activation)
494
+ self.normalize_before = normalize_before
495
+ self.nhead = nhead
496
+
497
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
498
+ return tensor if pos is None else tensor + pos
499
+
500
+ def forward_post(self,
501
+ src,
502
+ src_mask: Optional[Tensor] = None,
503
+ src_key_padding_mask: Optional[Tensor] = None,
504
+ pos: Optional[Tensor] = None,
505
+ video_length=None, dummy=True):
506
+ assert video_length is not None
507
+ pos_src = self.with_pos_embed(src, pos)
508
+ q, k, v = pos_src[:video_length], pos_src[video_length:], src[video_length:]
509
+
510
+ qmask, kmask = src_key_padding_mask[:, :video_length].unsqueeze(2), src_key_padding_mask[:, video_length:].unsqueeze(1)
511
+ attn_mask = torch.matmul(qmask.float(), kmask.float()).bool().repeat(self.nhead, 1, 1)
512
+
513
+ # - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
514
+ # If a FloatTensor is provided, it will be directly added to the value.
515
+ # If a BoolTensor is provided, the positions with the
516
+ # value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
517
+ # - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
518
+ # 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
519
+ # S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
520
+ # positions. If a BoolTensor is provided, positions with ``True``
521
+ # are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
522
+ # is provided, it will be added to the attention weight.
523
+ # print(q.shape, k.shape, v.shape, attn_mask.shape, src_key_padding_mask[:, video_length + 1:].shape)
524
+
525
+ # import pdb; pdb.set_trace()
526
+ src2, attn_weights = self.self_attn(q, k, v, attn_mask=attn_mask, key_padding_mask=src_key_padding_mask[:, video_length:], dummy=dummy)
527
+
528
+ src2 = src[:video_length] + self.dropout1(src2)
529
+ src3 = self.norm1(src2)
530
+ src3 = self.linear2(self.dropout(self.activation(self.linear1(src3))))
531
+ src2 = src2 + self.dropout2(src3)
532
+ src2 = self.norm2(src2)
533
+
534
+ src = torch.cat([src2, src[video_length:]])
535
+ return src, attn_weights
536
+
537
+ def forward_pre(self, src,
538
+ src_mask: Optional[Tensor] = None,
539
+ src_key_padding_mask: Optional[Tensor] = None,
540
+ pos: Optional[Tensor] = None, dummy=True):
541
+ pass
542
+
543
+
544
+ def forward(self, src,
545
+ src_mask: Optional[Tensor] = None,
546
+ src_key_padding_mask: Optional[Tensor] = None,
547
+ pos: Optional[Tensor] = None, dummy=True,
548
+ **kwargs):
549
+ if self.normalize_before:
550
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos, dummy=dummy)
551
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos, dummy=dummy, **kwargs)
552
+
553
+ class TransformerEncoderLayer(nn.Module):
554
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
555
+ activation="relu", normalize_before=False):
556
+ super().__init__()
557
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
558
+ # Implementation of Feedforward model
559
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
560
+ self.dropout = nn.Dropout(dropout)
561
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
562
+
563
+ self.norm1 = nn.LayerNorm(d_model)
564
+ self.norm2 = nn.LayerNorm(d_model)
565
+ self.dropout1 = DropPath(dropout)
566
+ self.dropout2 = DropPath(dropout)
567
+
568
+ self.activation = _get_activation_fn(activation)
569
+ self.normalize_before = normalize_before
570
+
571
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
572
+ return tensor if pos is None else tensor + pos
573
+
574
+ def forward_post(self,
575
+ src,
576
+ src_mask: Optional[Tensor] = None,
577
+ src_key_padding_mask: Optional[Tensor] = None,
578
+ pos: Optional[Tensor] = None):
579
+ q = k = self.with_pos_embed(src, pos)
580
+ src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
581
+ key_padding_mask=src_key_padding_mask)[0]
582
+ src = src + self.dropout1(src2)
583
+ src = self.norm1(src)
584
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
585
+ src = src + self.dropout2(src2)
586
+ src = self.norm2(src)
587
+ return src
588
+
589
+ def forward_pre(self, src,
590
+ src_mask: Optional[Tensor] = None,
591
+ src_key_padding_mask: Optional[Tensor] = None,
592
+ pos: Optional[Tensor] = None):
593
+ pass
594
+
595
+ def forward(self, src,
596
+ src_mask: Optional[Tensor] = None,
597
+ src_key_padding_mask: Optional[Tensor] = None,
598
+ pos: Optional[Tensor] = None):
599
+ if self.normalize_before:
600
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
601
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
602
+
603
+
604
+ class TransformerDecoderLayer(nn.Module):
605
+
606
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
607
+ activation="relu", normalize_before=False, keep_query_pos=False,
608
+ rm_self_attn_decoder=False):
609
+ super().__init__()
610
+ # Decoder Self-Attention
611
+ if not rm_self_attn_decoder:
612
+ self.sa_qcontent_proj = nn.Linear(d_model, d_model)
613
+ self.sa_qpos_proj = nn.Linear(d_model, d_model)
614
+ self.sa_kcontent_proj = nn.Linear(d_model, d_model)
615
+ self.sa_kpos_proj = nn.Linear(d_model, d_model)
616
+ self.sa_v_proj = nn.Linear(d_model, d_model)
617
+ self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, vdim=d_model)
618
+
619
+ self.norm1 = nn.LayerNorm(d_model)
620
+ self.dropout1 = DropPath(dropout)
621
+
622
+ # Decoder Cross-Attention
623
+ self.ca_qcontent_proj = nn.Linear(d_model, d_model)
624
+ self.ca_qpos_proj = nn.Linear(d_model, d_model)
625
+ self.ca_kcontent_proj = nn.Linear(d_model, d_model)
626
+ self.ca_kpos_proj = nn.Linear(d_model, d_model)
627
+ self.ca_v_proj = nn.Linear(d_model, d_model)
628
+ self.ca_qpos_sine_proj = nn.Linear(d_model, d_model)
629
+ self.cross_attn = MultiheadAttention(d_model * 2, nhead, dropout=dropout, vdim=d_model)
630
+
631
+ self.nhead = nhead
632
+ self.rm_self_attn_decoder = rm_self_attn_decoder
633
+
634
+ # Implementation of Feedforward model
635
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
636
+ self.dropout = nn.Dropout(dropout)
637
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
638
+
639
+ self.norm2 = nn.LayerNorm(d_model)
640
+ self.norm3 = nn.LayerNorm(d_model)
641
+ self.dropout2 = DropPath(dropout)
642
+ self.dropout3 = DropPath(dropout)
643
+
644
+ self.activation = _get_activation_fn(activation)
645
+ self.normalize_before = normalize_before
646
+ self.keep_query_pos = keep_query_pos
647
+
648
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
649
+ return tensor if pos is None else tensor + pos
650
+
651
+ def forward(self, tgt, memory,
652
+ tgt_mask: Optional[Tensor] = None,
653
+ memory_mask: Optional[Tensor] = None,
654
+ tgt_key_padding_mask: Optional[Tensor] = None,
655
+ memory_key_padding_mask: Optional[Tensor] = None,
656
+ pos: Optional[Tensor] = None,
657
+ query_pos: Optional[Tensor] = None,
658
+ query_sine_embed=None,
659
+ is_first=False):
660
+
661
+ # ========== Begin of Self-Attention =============
662
+ if not self.rm_self_attn_decoder:
663
+ # Apply projections here
664
+ # shape: num_queries x batch_size x 256
665
+ q_content = self.sa_qcontent_proj(tgt) # target is the input of the first decoder layer. zero by default.
666
+ q_pos = self.sa_qpos_proj(query_pos)
667
+ k_content = self.sa_kcontent_proj(tgt)
668
+ k_pos = self.sa_kpos_proj(query_pos)
669
+ v = self.sa_v_proj(tgt)
670
+
671
+ num_queries, bs, n_model = q_content.shape
672
+ hw, _, _ = k_content.shape
673
+
674
+ q = q_content + q_pos
675
+ k = k_content + k_pos
676
+
677
+ tgt2 = self.self_attn(q, k, value=v, attn_mask=tgt_mask,
678
+ key_padding_mask=tgt_key_padding_mask)[0]
679
+ # ========== End of Self-Attention =============
680
+
681
+ tgt = tgt + self.dropout1(tgt2)
682
+ tgt = self.norm1(tgt)
683
+
684
+ # ========== Begin of Cross-Attention =============
685
+ # Apply projections here
686
+ # shape: num_queries x batch_size x 256
687
+ q_content = self.ca_qcontent_proj(tgt)
688
+ k_content = self.ca_kcontent_proj(memory)
689
+ v = self.ca_v_proj(memory)
690
+
691
+ num_queries, bs, n_model = q_content.shape
692
+ hw, _, _ = k_content.shape
693
+
694
+ k_pos = self.ca_kpos_proj(pos)
695
+
696
+ # For the first decoder layer, we concatenate the positional embedding predicted from
697
+ # the object query (the positional embedding) into the original query (key) in DETR.
698
+ if is_first or self.keep_query_pos:
699
+ q_pos = self.ca_qpos_proj(query_pos)
700
+ q = q_content + q_pos
701
+ k = k_content + k_pos
702
+ else:
703
+ q = q_content
704
+ k = k_content
705
+
706
+ q = q.view(num_queries, bs, self.nhead, n_model // self.nhead)
707
+ query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed)
708
+ query_sine_embed = query_sine_embed.view(num_queries, bs, self.nhead, n_model // self.nhead)
709
+ q = torch.cat([q, query_sine_embed], dim=3).view(num_queries, bs, n_model * 2)
710
+ k = k.view(hw, bs, self.nhead, n_model // self.nhead)
711
+ k_pos = k_pos.view(hw, bs, self.nhead, n_model // self.nhead)
712
+ k = torch.cat([k, k_pos], dim=3).view(hw, bs, n_model * 2)
713
+
714
+ tgt2 = self.cross_attn(query=q,
715
+ key=k,
716
+ value=v, attn_mask=memory_mask,
717
+ key_padding_mask=memory_key_padding_mask)[0]
718
+ # ========== End of Cross-Attention =============
719
+
720
+ tgt = tgt + self.dropout2(tgt2)
721
+ tgt = self.norm2(tgt)
722
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
723
+ tgt = tgt + self.dropout3(tgt2)
724
+ tgt = self.norm3(tgt)
725
+ return tgt
726
+
727
+
728
+ class TransformerDecoderLayerThin(nn.Module):
729
+ """removed intermediate layer"""
730
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
731
+ activation="relu", normalize_before=False):
732
+ super().__init__()
733
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
734
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
735
+ # Implementation of Feedforward model
736
+ self.linear1 = nn.Linear(d_model, d_model)
737
+
738
+
739
+ self.norm1 = nn.LayerNorm(d_model)
740
+ self.norm2 = nn.LayerNorm(d_model)
741
+ # self.norm3 = nn.LayerNorm(d_model)
742
+ self.dropout1 = DropPath(dropout)
743
+ self.dropout2 = DropPath(dropout)
744
+
745
+
746
+ # self.activation = _get_activation_fn(activation)
747
+ self.normalize_before = normalize_before
748
+
749
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
750
+ return tensor if pos is None else tensor + pos
751
+
752
+ def forward_post(self, tgt, memory,
753
+ tgt_mask: Optional[Tensor] = None,
754
+ memory_mask: Optional[Tensor] = None,
755
+ tgt_key_padding_mask: Optional[Tensor] = None,
756
+ memory_key_padding_mask: Optional[Tensor] = None,
757
+ pos: Optional[Tensor] = None,
758
+ query_pos: Optional[Tensor] = None):
759
+ q = k = self.with_pos_embed(tgt, query_pos)
760
+ tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
761
+ key_padding_mask=tgt_key_padding_mask)[0]
762
+ tgt = tgt + self.dropout1(tgt2)
763
+ tgt = self.norm1(tgt)
764
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
765
+ key=self.with_pos_embed(memory, pos),
766
+ value=memory, attn_mask=memory_mask,
767
+ key_padding_mask=memory_key_padding_mask)[0]
768
+ tgt2 = self.linear1(tgt2)
769
+ tgt = tgt + self.dropout2(tgt2)
770
+ tgt = self.norm2(tgt)
771
+ return tgt
772
+
773
+ def forward_pre(self, tgt, memory,
774
+ tgt_mask: Optional[Tensor] = None,
775
+ memory_mask: Optional[Tensor] = None,
776
+ tgt_key_padding_mask: Optional[Tensor] = None,
777
+ memory_key_padding_mask: Optional[Tensor] = None,
778
+ pos: Optional[Tensor] = None,
779
+ query_pos: Optional[Tensor] = None):
780
+ tgt2 = self.norm1(tgt)
781
+ q = k = self.with_pos_embed(tgt2, query_pos)
782
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
783
+ key_padding_mask=tgt_key_padding_mask)[0]
784
+ tgt = tgt + self.dropout1(tgt2)
785
+ tgt2 = self.norm2(tgt)
786
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
787
+ key=self.with_pos_embed(memory, pos),
788
+ value=memory, attn_mask=memory_mask,
789
+ key_padding_mask=memory_key_padding_mask)[0]
790
+ tgt = tgt + self.dropout2(tgt2)
791
+ tgt2 = self.norm3(tgt)
792
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
793
+ tgt = tgt + self.dropout3(tgt2)
794
+ return tgt
795
+
796
+ def forward(self, tgt, memory,
797
+ tgt_mask: Optional[Tensor] = None,
798
+ memory_mask: Optional[Tensor] = None,
799
+ tgt_key_padding_mask: Optional[Tensor] = None,
800
+ memory_key_padding_mask: Optional[Tensor] = None,
801
+ pos: Optional[Tensor] = None,
802
+ query_pos: Optional[Tensor] = None):
803
+ if self.normalize_before:
804
+ return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
805
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
806
+ return self.forward_post(tgt, memory, tgt_mask, memory_mask,
807
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
808
+
809
+
810
+
811
+ def _get_clones(module, N):
812
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
813
+
814
+
815
+ def build_transformer(args):
816
+ return Transformer(
817
+ d_model=args.hidden_dim,
818
+ dropout=args.dropout,
819
+ nhead=args.nheads,
820
+ dim_feedforward=args.dim_feedforward,
821
+ num_encoder_layers=args.enc_layers,
822
+ num_decoder_layers=args.dec_layers,
823
+ normalize_before=args.pre_norm,
824
+ return_intermediate_dec=True,
825
+ activation='prelu',
826
+ args=args
827
+ )
828
+
829
+ def drop_path(x, drop_prob=0.0, training=False):
830
+ """
831
+ Stochastic Depth per sample.
832
+ """
833
+ if drop_prob == 0.0 or not training:
834
+ return x
835
+
836
+ keep_prob = 1 - drop_prob
837
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
838
+ mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
839
+ mask.floor_()
840
+ x = x.div(keep_prob) * mask
841
+
842
+ return x
843
+
844
+ class DropPath(nn.Module):
845
+ """
846
+ Drop paths per sample (when applied in main path of residual blocks).
847
+ """
848
+
849
+ def __init__(self, drop_prob=None):
850
+ super(DropPath, self).__init__()
851
+
852
+ self.drop_prob = drop_prob
853
+
854
+ def forward(self, x):
855
+ x = x.permute(1, 0, 2)
856
+ res = drop_path(x, self.drop_prob, self.training)
857
+ return res.permute(1, 0, 2)
858
+
859
+ def _get_activation_fn(activation):
860
+ """Return an activation function given a string"""
861
+ if activation == "relu":
862
+ return F.relu
863
+ if activation == "gelu":
864
+ return F.gelu
865
+ if activation == "glu":
866
+ return F.glu
867
+ if activation == "prelu":
868
+ return nn.PReLU()
869
+ if activation == "selu":
870
+ return F.selu
871
+ raise RuntimeError(f"activation should be relu/gelu/glu/prelu/selu, not {activation}.")
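As a quick, hedged illustration of the helpers defined at the end of this file, the sketch below exercises `DropPath` and `_get_activation_fn` on dummy tensors. The import path and tensor sizes are assumptions based on this repo's layout, not anything specified in the diff.

```python
# Illustrative sketch only: the import path below is an assumed module path for
# third_party/cgdetr/cg_detr/transformer.py; adjust it to your PYTHONPATH.
import torch

from third_party.cgdetr.cg_detr.transformer import DropPath, _get_activation_fn

x = torch.randn(10, 4, 256)        # (seq_len, batch, d_model), the layout used by the layers above

dp = DropPath(drop_prob=0.1)
dp.train()
out_train = dp(x)                  # ~10% of samples in the batch are zeroed, the rest rescaled by 1/0.9
dp.eval()
out_eval = dp(x)                   # identity at inference time
assert torch.equal(out_eval, x)

act = _get_activation_fn("prelu")  # an nn.PReLU() module; "relu"/"gelu"/"glu"/"selu" return functional ops
y = act(torch.randn(4, 256))
print(out_train.shape, y.shape)
```

Note that `DropPath` permutes its input to batch-first internally, so the stochastic-depth mask is applied per sample in the batch rather than per token.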
third_party/cgdetr/data/LICENSE ADDED
@@ -0,0 +1,437 @@
1
+ Attribution-NonCommercial-ShareAlike 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
58
+ Public License
59
+
60
+ By exercising the Licensed Rights (defined below), You accept and agree
61
+ to be bound by the terms and conditions of this Creative Commons
62
+ Attribution-NonCommercial-ShareAlike 4.0 International Public License
63
+ ("Public License"). To the extent this Public License may be
64
+ interpreted as a contract, You are granted the Licensed Rights in
65
+ consideration of Your acceptance of these terms and conditions, and the
66
+ Licensor grants You such rights in consideration of benefits the
67
+ Licensor receives from making the Licensed Material available under
68
+ these terms and conditions.
69
+
70
+
71
+ Section 1 -- Definitions.
72
+
73
+ a. Adapted Material means material subject to Copyright and Similar
74
+ Rights that is derived from or based upon the Licensed Material
75
+ and in which the Licensed Material is translated, altered,
76
+ arranged, transformed, or otherwise modified in a manner requiring
77
+ permission under the Copyright and Similar Rights held by the
78
+ Licensor. For purposes of this Public License, where the Licensed
79
+ Material is a musical work, performance, or sound recording,
80
+ Adapted Material is always produced where the Licensed Material is
81
+ synched in timed relation with a moving image.
82
+
83
+ b. Adapter's License means the license You apply to Your Copyright
84
+ and Similar Rights in Your contributions to Adapted Material in
85
+ accordance with the terms and conditions of this Public License.
86
+
87
+ c. BY-NC-SA Compatible License means a license listed at
88
+ creativecommons.org/compatiblelicenses, approved by Creative
89
+ Commons as essentially the equivalent of this Public License.
90
+
91
+ d. Copyright and Similar Rights means copyright and/or similar rights
92
+ closely related to copyright including, without limitation,
93
+ performance, broadcast, sound recording, and Sui Generis Database
94
+ Rights, without regard to how the rights are labeled or
95
+ categorized. For purposes of this Public License, the rights
96
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
97
+ Rights.
98
+
99
+ e. Effective Technological Measures means those measures that, in the
100
+ absence of proper authority, may not be circumvented under laws
101
+ fulfilling obligations under Article 11 of the WIPO Copyright
102
+ Treaty adopted on December 20, 1996, and/or similar international
103
+ agreements.
104
+
105
+ f. Exceptions and Limitations means fair use, fair dealing, and/or
106
+ any other exception or limitation to Copyright and Similar Rights
107
+ that applies to Your use of the Licensed Material.
108
+
109
+ g. License Elements means the license attributes listed in the name
110
+ of a Creative Commons Public License. The License Elements of this
111
+ Public License are Attribution, NonCommercial, and ShareAlike.
112
+
113
+ h. Licensed Material means the artistic or literary work, database,
114
+ or other material to which the Licensor applied this Public
115
+ License.
116
+
117
+ i. Licensed Rights means the rights granted to You subject to the
118
+ terms and conditions of this Public License, which are limited to
119
+ all Copyright and Similar Rights that apply to Your use of the
120
+ Licensed Material and that the Licensor has authority to license.
121
+
122
+ j. Licensor means the individual(s) or entity(ies) granting rights
123
+ under this Public License.
124
+
125
+ k. NonCommercial means not primarily intended for or directed towards
126
+ commercial advantage or monetary compensation. For purposes of
127
+ this Public License, the exchange of the Licensed Material for
128
+ other material subject to Copyright and Similar Rights by digital
129
+ file-sharing or similar means is NonCommercial provided there is
130
+ no payment of monetary compensation in connection with the
131
+ exchange.
132
+
133
+ l. Share means to provide material to the public by any means or
134
+ process that requires permission under the Licensed Rights, such
135
+ as reproduction, public display, public performance, distribution,
136
+ dissemination, communication, or importation, and to make material
137
+ available to the public including in ways that members of the
138
+ public may access the material from a place and at a time
139
+ individually chosen by them.
140
+
141
+ m. Sui Generis Database Rights means rights other than copyright
142
+ resulting from Directive 96/9/EC of the European Parliament and of
143
+ the Council of 11 March 1996 on the legal protection of databases,
144
+ as amended and/or succeeded, as well as other essentially
145
+ equivalent rights anywhere in the world.
146
+
147
+ n. You means the individual or entity exercising the Licensed Rights
148
+ under this Public License. Your has a corresponding meaning.
149
+
150
+
151
+ Section 2 -- Scope.
152
+
153
+ a. License grant.
154
+
155
+ 1. Subject to the terms and conditions of this Public License,
156
+ the Licensor hereby grants You a worldwide, royalty-free,
157
+ non-sublicensable, non-exclusive, irrevocable license to
158
+ exercise the Licensed Rights in the Licensed Material to:
159
+
160
+ a. reproduce and Share the Licensed Material, in whole or
161
+ in part, for NonCommercial purposes only; and
162
+
163
+ b. produce, reproduce, and Share Adapted Material for
164
+ NonCommercial purposes only.
165
+
166
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
167
+ Exceptions and Limitations apply to Your use, this Public
168
+ License does not apply, and You do not need to comply with
169
+ its terms and conditions.
170
+
171
+ 3. Term. The term of this Public License is specified in Section
172
+ 6(a).
173
+
174
+ 4. Media and formats; technical modifications allowed. The
175
+ Licensor authorizes You to exercise the Licensed Rights in
176
+ all media and formats whether now known or hereafter created,
177
+ and to make technical modifications necessary to do so. The
178
+ Licensor waives and/or agrees not to assert any right or
179
+ authority to forbid You from making technical modifications
180
+ necessary to exercise the Licensed Rights, including
181
+ technical modifications necessary to circumvent Effective
182
+ Technological Measures. For purposes of this Public License,
183
+ simply making modifications authorized by this Section 2(a)
184
+ (4) never produces Adapted Material.
185
+
186
+ 5. Downstream recipients.
187
+
188
+ a. Offer from the Licensor -- Licensed Material. Every
189
+ recipient of the Licensed Material automatically
190
+ receives an offer from the Licensor to exercise the
191
+ Licensed Rights under the terms and conditions of this
192
+ Public License.
193
+
194
+ b. Additional offer from the Licensor -- Adapted Material.
195
+ Every recipient of Adapted Material from You
196
+ automatically receives an offer from the Licensor to
197
+ exercise the Licensed Rights in the Adapted Material
198
+ under the conditions of the Adapter's License You apply.
199
+
200
+ c. No downstream restrictions. You may not offer or impose
201
+ any additional or different terms or conditions on, or
202
+ apply any Effective Technological Measures to, the
203
+ Licensed Material if doing so restricts exercise of the
204
+ Licensed Rights by any recipient of the Licensed
205
+ Material.
206
+
207
+ 6. No endorsement. Nothing in this Public License constitutes or
208
+ may be construed as permission to assert or imply that You
209
+ are, or that Your use of the Licensed Material is, connected
210
+ with, or sponsored, endorsed, or granted official status by,
211
+ the Licensor or others designated to receive attribution as
212
+ provided in Section 3(a)(1)(A)(i).
213
+
214
+ b. Other rights.
215
+
216
+ 1. Moral rights, such as the right of integrity, are not
217
+ licensed under this Public License, nor are publicity,
218
+ privacy, and/or other similar personality rights; however, to
219
+ the extent possible, the Licensor waives and/or agrees not to
220
+ assert any such rights held by the Licensor to the limited
221
+ extent necessary to allow You to exercise the Licensed
222
+ Rights, but not otherwise.
223
+
224
+ 2. Patent and trademark rights are not licensed under this
225
+ Public License.
226
+
227
+ 3. To the extent possible, the Licensor waives any right to
228
+ collect royalties from You for the exercise of the Licensed
229
+ Rights, whether directly or through a collecting society
230
+ under any voluntary or waivable statutory or compulsory
231
+ licensing scheme. In all other cases the Licensor expressly
232
+ reserves any right to collect such royalties, including when
233
+ the Licensed Material is used other than for NonCommercial
234
+ purposes.
235
+
236
+
237
+ Section 3 -- License Conditions.
238
+
239
+ Your exercise of the Licensed Rights is expressly made subject to the
240
+ following conditions.
241
+
242
+ a. Attribution.
243
+
244
+ 1. If You Share the Licensed Material (including in modified
245
+ form), You must:
246
+
247
+ a. retain the following if it is supplied by the Licensor
248
+ with the Licensed Material:
249
+
250
+ i. identification of the creator(s) of the Licensed
251
+ Material and any others designated to receive
252
+ attribution, in any reasonable manner requested by
253
+ the Licensor (including by pseudonym if
254
+ designated);
255
+
256
+ ii. a copyright notice;
257
+
258
+ iii. a notice that refers to this Public License;
259
+
260
+ iv. a notice that refers to the disclaimer of
261
+ warranties;
262
+
263
+ v. a URI or hyperlink to the Licensed Material to the
264
+ extent reasonably practicable;
265
+
266
+ b. indicate if You modified the Licensed Material and
267
+ retain an indication of any previous modifications; and
268
+
269
+ c. indicate the Licensed Material is licensed under this
270
+ Public License, and include the text of, or the URI or
271
+ hyperlink to, this Public License.
272
+
273
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
274
+ reasonable manner based on the medium, means, and context in
275
+ which You Share the Licensed Material. For example, it may be
276
+ reasonable to satisfy the conditions by providing a URI or
277
+ hyperlink to a resource that includes the required
278
+ information.
279
+ 3. If requested by the Licensor, You must remove any of the
280
+ information required by Section 3(a)(1)(A) to the extent
281
+ reasonably practicable.
282
+
283
+ b. ShareAlike.
284
+
285
+ In addition to the conditions in Section 3(a), if You Share
286
+ Adapted Material You produce, the following conditions also apply.
287
+
288
+ 1. The Adapter's License You apply must be a Creative Commons
289
+ license with the same License Elements, this version or
290
+ later, or a BY-NC-SA Compatible License.
291
+
292
+ 2. You must include the text of, or the URI or hyperlink to, the
293
+ Adapter's License You apply. You may satisfy this condition
294
+ in any reasonable manner based on the medium, means, and
295
+ context in which You Share Adapted Material.
296
+
297
+ 3. You may not offer or impose any additional or different terms
298
+ or conditions on, or apply any Effective Technological
299
+ Measures to, Adapted Material that restrict exercise of the
300
+ rights granted under the Adapter's License You apply.
301
+
302
+
303
+ Section 4 -- Sui Generis Database Rights.
304
+
305
+ Where the Licensed Rights include Sui Generis Database Rights that
306
+ apply to Your use of the Licensed Material:
307
+
308
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
309
+ to extract, reuse, reproduce, and Share all or a substantial
310
+ portion of the contents of the database for NonCommercial purposes
311
+ only;
312
+
313
+ b. if You include all or a substantial portion of the database
314
+ contents in a database in which You have Sui Generis Database
315
+ Rights, then the database in which You have Sui Generis Database
316
+ Rights (but not its individual contents) is Adapted Material,
317
+ including for purposes of Section 3(b); and
318
+
319
+ c. You must comply with the conditions in Section 3(a) if You Share
320
+ all or a substantial portion of the contents of the database.
321
+
322
+ For the avoidance of doubt, this Section 4 supplements and does not
323
+ replace Your obligations under this Public License where the Licensed
324
+ Rights include other Copyright and Similar Rights.
325
+
326
+
327
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
328
+
329
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
330
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
331
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
332
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
333
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
334
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
335
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
336
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
337
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
338
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
339
+
340
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
341
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
342
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
343
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
344
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
345
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
346
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
347
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
348
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
349
+
350
+ c. The disclaimer of warranties and limitation of liability provided
351
+ above shall be interpreted in a manner that, to the extent
352
+ possible, most closely approximates an absolute disclaimer and
353
+ waiver of all liability.
354
+
355
+
356
+ Section 6 -- Term and Termination.
357
+
358
+ a. This Public License applies for the term of the Copyright and
359
+ Similar Rights licensed here. However, if You fail to comply with
360
+ this Public License, then Your rights under this Public License
361
+ terminate automatically.
362
+
363
+ b. Where Your right to use the Licensed Material has terminated under
364
+ Section 6(a), it reinstates:
365
+
366
+ 1. automatically as of the date the violation is cured, provided
367
+ it is cured within 30 days of Your discovery of the
368
+ violation; or
369
+
370
+ 2. upon express reinstatement by the Licensor.
371
+
372
+ For the avoidance of doubt, this Section 6(b) does not affect any
373
+ right the Licensor may have to seek remedies for Your violations
374
+ of this Public License.
375
+
376
+ c. For the avoidance of doubt, the Licensor may also offer the
377
+ Licensed Material under separate terms or conditions or stop
378
+ distributing the Licensed Material at any time; however, doing so
379
+ will not terminate this Public License.
380
+
381
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
382
+ License.
383
+
384
+
385
+ Section 7 -- Other Terms and Conditions.
386
+
387
+ a. The Licensor shall not be bound by any additional or different
388
+ terms or conditions communicated by You unless expressly agreed.
389
+
390
+ b. Any arrangements, understandings, or agreements regarding the
391
+ Licensed Material not stated herein are separate from and
392
+ independent of the terms and conditions of this Public License.
393
+
394
+
395
+ Section 8 -- Interpretation.
396
+
397
+ a. For the avoidance of doubt, this Public License does not, and
398
+ shall not be interpreted to, reduce, limit, restrict, or impose
399
+ conditions on any use of the Licensed Material that could lawfully
400
+ be made without permission under this Public License.
401
+
402
+ b. To the extent possible, if any provision of this Public License is
403
+ deemed unenforceable, it shall be automatically reformed to the
404
+ minimum extent necessary to make it enforceable. If the provision
405
+ cannot be reformed, it shall be severed from this Public License
406
+ without affecting the enforceability of the remaining terms and
407
+ conditions.
408
+
409
+ c. No term or condition of this Public License will be waived and no
410
+ failure to comply consented to unless expressly agreed to by the
411
+ Licensor.
412
+
413
+ d. Nothing in this Public License constitutes or may be interpreted
414
+ as a limitation upon, or waiver of, any privileges and immunities
415
+ that apply to the Licensor or You, including from the legal
416
+ processes of any jurisdiction or authority.
417
+
418
+ =======================================================================
419
+
420
+ Creative Commons is not a party to its public
421
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
422
+ its public licenses to material it publishes and in those instances
423
+ will be considered the “Licensor.” The text of the Creative Commons
424
+ public licenses is dedicated to the public domain under the CC0 Public
425
+ Domain Dedication. Except for the limited purpose of indicating that
426
+ material is shared under a Creative Commons public license or as
427
+ otherwise permitted by the Creative Commons policies published at
428
+ creativecommons.org/policies, Creative Commons does not authorize the
429
+ use of the trademark "Creative Commons" or any other trademark or logo
430
+ of Creative Commons without its prior written consent including,
431
+ without limitation, in connection with any unauthorized modifications
432
+ to any of its public licenses or any other arrangements,
433
+ understandings, or agreements concerning use of licensed material. For
434
+ the avoidance of doubt, this paragraph does not form part of the
435
+ public licenses.
436
+
437
+ Creative Commons may be contacted at creativecommons.org.
third_party/cgdetr/data/README.md ADDED
@@ -0,0 +1,24 @@
1
+ ## QVHighlights Dataset
2
+
3
+ Our annotation files include 3 splits: `train`, `val` and `test`. Each file is in [JSON Line](https://jsonlines.org/) format; each row can be loaded as a single `dict` in Python. Below is an example of the annotation:
4
+
5
+ ```
6
+ {
7
+ "qid": 8737,
8
+ "query": "A family is playing basketball together on a green court outside.",
9
+ "duration": 126,
10
+ "vid": "bP5KfdFJzC4_660.0_810.0",
11
+ "relevant_windows": [[0, 16]],
12
+ "relevant_clip_ids": [0, 1, 2, 3, 4, 5, 6, 7],
13
+ "saliency_scores": [[4, 1, 1], [4, 1, 1], [4, 2, 1], [4, 3, 2], [4, 3, 2], [4, 3, 3], [4, 3, 3], [4, 3, 2]]
14
+ }
15
+ ```
16
+ `qid` is a unique identifier of a `query`. This query corresponds to a video identified by its video id `vid`. The `vid` is formatted as `{youtube_id}_{start_time}_{end_time}`. Using this information, one can retrieve the YouTube video from the URL `https://www.youtube.com/embed/{youtube_id}?start={start_time}&end={end_time}&version=3`. For example, the video in this example is `https://www.youtube.com/embed/bP5KfdFJzC4?start=660&end=810&version=3`.
17
+ `duration` is an integer indicating the duration of this video.
18
+ `relevant_windows` is the list of windows that localize the moments; each window has two numbers, the start time and the end time of the moment. `relevant_clip_ids` is the list of ids of the segmented 2-second clips that fall into the moments specified by `relevant_windows`, starting from 0.
19
+ `saliency_scores` contains the saliency score annotations; each sublist corresponds to a clip in `relevant_clip_ids`. There are 3 elements in each sublist, which are the scores from three different annotators. A score of `4` means `Very Good`, while `0` means `Very Bad`.
20
+
21
+ Note that the three fields `relevant_clip_ids`, `relevant_windows` and `saliency_scores` are not included for the `test` split. Please refer to [../standalone_eval/README.md](../standalone_eval/README.md) for details on evaluating predictions on `test`.
22
+
23
+ In addition to the annotation files, we also provide the subtitle file for our weakly supervised ASR pre-training: [subs_train.jsonl](./subs_train.jsonl). This file is formatted similarly to our annotation files, but without the `saliency_scores` entry. It is not needed if you do not plan to pretrain models with it.
24
+
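To make the schema above concrete, here is a small illustrative sketch (not part of the dataset release; the file path is a placeholder) that loads the JSON Lines annotations and expands `relevant_clip_ids` / `saliency_scores` into a per-clip score array, following the field descriptions in this README.

```python
import json
import numpy as np

def load_jsonl(path):
    # Each line of a QVHighlights annotation file is one JSON object.
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]

annos = load_jsonl("highlight_train_release.jsonl")  # placeholder path
anno = annos[0]

clip_len = 2  # annotations are defined over 2-second clips
num_clips = int(anno["duration"] / clip_len)

# (num_clips, 3) array: one row per clip, one column per annotator, scores in [0, 4]
scores = np.zeros((num_clips, 3))
scores[np.array(anno["relevant_clip_ids"])] = np.array(anno["saliency_scores"])

# relevant_windows are [start_sec, end_sec] pairs; map them back to clip indices
for start, end in anno["relevant_windows"]:
    clip_ids = list(range(int(start // clip_len), int(end // clip_len)))
    print(anno["qid"], (start, end), clip_ids)
```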
third_party/cgdetr/standalone_eval/README.md ADDED
@@ -0,0 +1,54 @@
1
+ QVHighlights Evaluation and Codalab Submission
2
+ ==================
3
+
4
+ ### Task Definition
5
+ Given a video and a natural language query, our task requires a system to retrieve the most relevant moments in the video and to predict the highlight (saliency) score of each clip in the video.
6
+
7
+ ### Evaluation
8
+ At project root, run
9
+ ```
10
+ bash standalone_eval/eval_sample.sh
11
+ ```
12
+ This command will use [eval.py](eval.py) to evaluate the provided prediction file [sample_val_preds.jsonl](sample_val_preds.jsonl),
13
+ the output will be written into `sample_val_preds_metrics.json`.
14
+ The content of this generated file should be similar to, if not the same as, the [sample_val_preds_metrics_raw.json](sample_val_preds_metrics_raw.json) file.
15
+
16
+ ### Format
17
+
18
+ The prediction file [sample_val_preds.jsonl](sample_val_preds.jsonl) is in [JSON Line](https://jsonlines.org/) format, each row of the files can be loaded as a single `dict` in Python. Below is an example of a single line in the prediction file:
19
+ ```
20
+ {
21
+ "qid": 2579,
22
+ "query": "A girl and her mother cooked while talking with each other on facetime.",
23
+ "vid": "NUsG9BgSes0_210.0_360.0",
24
+ "pred_relevant_windows": [
25
+ [0, 70, 0.9986],
26
+ [78, 146, 0.4138],
27
+ [0, 146, 0.0444],
28
+ ...
29
+ ],
30
+ "pred_saliency_scores": [-0.2452, -0.3779, -0.4746, ...]
31
+ }
32
+
33
+ ```
34
+
35
+
36
+
37
+ | entry | description |
38
+ | --- | ----|
39
+ | `qid` | `int`, unique query id |
40
+ | `query` | `str`, natural language query, not used by the evaluation script |
41
+ | `vid` | `str`, unique video id |
42
+ | `pred_relevant_windows` | `list(list)`, moment retrieval predictions. Each sublist contains 3 elements, `[start (seconds), end (seconds), score]`|
43
+ | `pred_saliency_scores` | `list(float)`, highlight prediction scores, the higher the better. This list should contain one score for each 2-second clip in the video, in temporal order. |
44
+
45
+
46
+ ### Codalab Submission
47
+ To test your model's performance on the `test` split,
48
+ please submit both `val` and `test` predictions to our
49
+ [Codalab evaluation server](https://codalab.lisn.upsaclay.fr/competitions/6937).
50
+ The submission file should be a single `.zip` file (no enclosing folder)
51
+ that contains the two prediction files
52
+ `hl_val_submission.jsonl` and `hl_test_submission.jsonl`; each `*submission.jsonl` file
53
+ should be formatted as instructed above.
54
+
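The prediction format described above can be written directly from model outputs. Below is an illustrative sketch (the `predictions` list and the output filename are placeholders, not part of the repo) of emitting a submission file in the expected JSON Lines layout.

```python
import json

# Placeholder: one entry per query produced by your model, mirroring the example above.
predictions = [
    {
        "qid": 2579,
        "query": "A girl and her mother cooked while talking with each other on facetime.",
        "vid": "NUsG9BgSes0_210.0_360.0",
        # ranked moment windows: [start_sec, end_sec, confidence]
        "pred_relevant_windows": [[0, 70, 0.9986], [78, 146, 0.4138]],
        # one saliency score per 2-second clip, in temporal order (higher = more salient)
        "pred_saliency_scores": [-0.2452, -0.3779, -0.4746],
    },
]

with open("hl_val_submission.jsonl", "w") as f:
    for pred in predictions:
        f.write(json.dumps(pred) + "\n")
```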
third_party/cgdetr/standalone_eval/eval.py ADDED
@@ -0,0 +1,361 @@
1
+ import numpy as np
2
+ from collections import OrderedDict, defaultdict
3
+ import json
4
+ import time
5
+ import copy
6
+ import multiprocessing as mp
7
+ from standalone_eval.utils import compute_average_precision_detection, \
8
+ compute_temporal_iou_batch_cross, compute_temporal_iou_batch_paired, load_jsonl, get_ap
9
+
10
+
11
+ def compute_average_precision_detection_wrapper(
12
+ input_triple, tiou_thresholds=np.linspace(0.5, 0.95, 10)):
13
+ qid, ground_truth, prediction = input_triple
14
+ scores = compute_average_precision_detection(
15
+ ground_truth, prediction, tiou_thresholds=tiou_thresholds)
16
+ return qid, scores
17
+
18
+
19
+ def compute_mr_ap(submission, ground_truth, iou_thds=np.linspace(0.5, 0.95, 10),
20
+ max_gt_windows=None, max_pred_windows=10, num_workers=8, chunksize=50):
21
+ iou_thds = [float(f"{e:.2f}") for e in iou_thds]
22
+ pred_qid2data = defaultdict(list)
23
+ for d in submission:
24
+ pred_windows = d["pred_relevant_windows"][:max_pred_windows] \
25
+ if max_pred_windows is not None else d["pred_relevant_windows"]
26
+ qid = d["qid"]
27
+ for w in pred_windows:
28
+ pred_qid2data[qid].append({
29
+ "video-id": d["qid"], # in order to use the API
30
+ "t-start": w[0],
31
+ "t-end": w[1],
32
+ "score": w[2]
33
+ })
34
+
35
+ gt_qid2data = defaultdict(list)
36
+ for d in ground_truth:
37
+ gt_windows = d["relevant_windows"][:max_gt_windows] \
38
+ if max_gt_windows is not None else d["relevant_windows"]
39
+ qid = d["qid"]
40
+ for w in gt_windows:
41
+ gt_qid2data[qid].append({
42
+ "video-id": d["qid"],
43
+ "t-start": w[0],
44
+ "t-end": w[1]
45
+ })
46
+ qid2ap_list = {}
47
+ # start_time = time.time()
48
+ data_triples = [[qid, gt_qid2data[qid], pred_qid2data[qid]] for qid in pred_qid2data]
49
+ from functools import partial
50
+ compute_ap_from_triple = partial(
51
+ compute_average_precision_detection_wrapper, tiou_thresholds=iou_thds)
52
+
53
+ if num_workers > 1:
54
+ with mp.Pool(num_workers) as pool:
55
+ for qid, scores in pool.imap_unordered(compute_ap_from_triple, data_triples, chunksize=chunksize):
56
+ qid2ap_list[qid] = scores
57
+ else:
58
+ for data_triple in data_triples:
59
+ qid, scores = compute_ap_from_triple(data_triple)
60
+ qid2ap_list[qid] = scores
61
+
62
+ # print(f"compute_average_precision_detection {time.time() - start_time:.2f} seconds.")
63
+ ap_array = np.array(list(qid2ap_list.values())) # (#queries, #thd)
64
+ ap_thds = ap_array.mean(0) # mAP at different IoU thresholds.
65
+ iou_thd2ap = dict(zip([str(e) for e in iou_thds], ap_thds))
66
+ iou_thd2ap["average"] = np.mean(ap_thds)
67
+ # formatting
68
+ iou_thd2ap = {k: float(f"{100 * v:.2f}") for k, v in iou_thd2ap.items()}
69
+ return iou_thd2ap
70
+
71
+
72
+ def compute_mr_r1(submission, ground_truth, iou_thds=np.linspace(0.3, 0.95, 14)):
73
+ """If a predicted segment has IoU >= iou_thd with one of the 1st GT segment, we define it positive"""
74
+ iou_thds = [float(f"{e:.2f}") for e in iou_thds]
75
+ pred_qid2window = {d["qid"]: d["pred_relevant_windows"][0][:2] for d in submission} # :2 rm scores
76
+ # gt_qid2window = {d["qid"]: d["relevant_windows"][0] for d in ground_truth}
77
+ gt_qid2window = {}
78
+ for d in ground_truth:
79
+ cur_gt_windows = d["relevant_windows"]
80
+ cur_qid = d["qid"]
81
+ cur_max_iou_idx = 0
82
+ if len(cur_gt_windows) > 0: # select the GT window that has the highest IoU
83
+ cur_ious = compute_temporal_iou_batch_cross(
84
+ np.array([pred_qid2window[cur_qid]]), np.array(d["relevant_windows"])
85
+ )[0]
86
+ cur_max_iou_idx = np.argmax(cur_ious)
87
+ gt_qid2window[cur_qid] = cur_gt_windows[cur_max_iou_idx]
88
+
89
+ qids = list(pred_qid2window.keys())
90
+ pred_windows = np.array([pred_qid2window[k] for k in qids]).astype(float)
91
+ gt_windows = np.array([gt_qid2window[k] for k in qids]).astype(float)
92
+ pred_gt_iou = compute_temporal_iou_batch_paired(pred_windows, gt_windows)
93
+ iou_thd2recall_at_one = {}
94
+ miou_at_one = float(f"{np.mean(pred_gt_iou) * 100:.2f}")
95
+ for thd in iou_thds:
96
+ iou_thd2recall_at_one[str(thd)] = float(f"{np.mean(pred_gt_iou >= thd) * 100:.2f}")
97
+ return iou_thd2recall_at_one, miou_at_one
98
+
99
+
100
+ def get_window_len(window):
101
+ return window[1] - window[0]
102
+
103
+
104
+ def get_data_by_range(submission, ground_truth, len_range):
105
+ """ keep queries with ground truth window length in the specified length range.
106
+ Args:
107
+ submission:
108
+ ground_truth:
109
+ len_range: [min_l (int), max_l (int)]. the range is (min_l, max_l], i.e., min_l < l <= max_l
110
+ """
111
+ min_l, max_l = len_range
112
+ if min_l == 0 and max_l == 150: # min and max l in dataset
113
+ return submission, ground_truth
114
+
115
+ # only keep ground truth with windows in the specified length range
116
+ # if multiple GT windows exists, we only keep the ones in the range
117
+ ground_truth_in_range = []
118
+ gt_qids_in_range = set()
119
+ for d in ground_truth:
120
+ rel_windows_in_range = [
121
+ w for w in d["relevant_windows"] if min_l < get_window_len(w) <= max_l]
122
+ if len(rel_windows_in_range) > 0:
123
+ d = copy.deepcopy(d)
124
+ d["relevant_windows"] = rel_windows_in_range
125
+ ground_truth_in_range.append(d)
126
+ gt_qids_in_range.add(d["qid"])
127
+
128
+ # keep only submissions for ground_truth_in_range
129
+ submission_in_range = []
130
+ for d in submission:
131
+ if d["qid"] in gt_qids_in_range:
132
+ submission_in_range.append(copy.deepcopy(d))
133
+
134
+ return submission_in_range, ground_truth_in_range
135
+
136
+
137
+ def eval_moment_retrieval(submission, ground_truth, verbose=True):
138
+ length_ranges = [[0, 10], [10, 30], [30, 150], [0, 150], ] #
139
+ range_names = ["short", "middle", "long", "full"]
140
+
141
+ ret_metrics = {}
142
+ for l_range, name in zip(length_ranges, range_names):
143
+ if verbose:
144
+ start_time = time.time()
145
+ _submission, _ground_truth = get_data_by_range(submission, ground_truth, l_range)
146
+ print(f"{name}: {l_range}, {len(_ground_truth)}/{len(ground_truth)}="
147
+ f"{100*len(_ground_truth)/len(ground_truth):.2f} examples.")
148
+ if len(_ground_truth) == 0:
149
+ # ret_metrics[name] = {"MR-mAP": 0., "MR-R1": 0.}
150
+ dummy_dict = {}
151
+ for k in np.linspace(0.5, 0.95, 19):
152
+ dummy_dict[k] = 0.
153
+ dummy_dict['average'] = 0.
154
+ ret_metrics[name] = {"MR-mAP": dummy_dict, "MR-R1": dummy_dict}
155
+ else:
156
+ iou_thd2average_precision = compute_mr_ap(_submission, _ground_truth, num_workers=8, chunksize=50)
157
+ iou_thd2recall_at_one, miou_at_one = compute_mr_r1(_submission, _ground_truth)
158
+ ret_metrics[name] = {"MR-mIoU": miou_at_one,
159
+ "MR-mAP": iou_thd2average_precision,
160
+ "MR-R1": iou_thd2recall_at_one}
161
+
162
+ # iou_thd2average_precision = compute_mr_ap(_submission, _ground_truth, num_workers=8, chunksize=50)
163
+ # iou_thd2recall_at_one = compute_mr_r1(_submission, _ground_truth)
164
+ # ret_metrics[name] = {"MR-mAP": iou_thd2average_precision, "MR-R1": iou_thd2recall_at_one}
165
+ if verbose:
166
+ print(f"[eval_moment_retrieval] [{name}] {time.time() - start_time:.2f} seconds")
167
+ return ret_metrics
168
+
169
+
170
+ def compute_hl_hit1(qid2preds, qid2gt_scores_binary):
171
+ qid2max_scored_clip_idx = {k: np.argmax(v["pred_saliency_scores"]) for k, v in qid2preds.items()}
172
+ hit_scores = np.zeros((len(qid2preds), 3))
173
+ qids = list(qid2preds.keys())
174
+ for idx, qid in enumerate(qids):
175
+ pred_clip_idx = qid2max_scored_clip_idx[qid]
176
+ gt_scores_binary = qid2gt_scores_binary[qid] # (#clips, 3)
177
+ if pred_clip_idx < len(gt_scores_binary):
178
+ hit_scores[idx] = gt_scores_binary[pred_clip_idx]
179
+ # aggregate scores from 3 separate annotations (3 workers) by taking the max.
180
+ # then average scores from all queries.
181
+ hit_at_one = float(f"{100 * np.mean(np.max(hit_scores, 1)):.2f}")
182
+ return hit_at_one
183
+
184
+
185
+ def compute_hl_ap(qid2preds, qid2gt_scores_binary, num_workers=8, chunksize=50):
186
+ qid2pred_scores = {k: v["pred_saliency_scores"] for k, v in qid2preds.items()}
187
+ ap_scores = np.zeros((len(qid2preds), 3)) # (#preds, 3)
188
+ qids = list(qid2preds.keys())
189
+ input_tuples = []
190
+ for idx, qid in enumerate(qids):
191
+ for w_idx in range(3): # annotation score idx
192
+ y_true = qid2gt_scores_binary[qid][:, w_idx]
193
+ y_predict = np.array(qid2pred_scores[qid])
194
+ input_tuples.append((idx, w_idx, y_true, y_predict))
195
+
196
+ if num_workers > 1:
197
+ with mp.Pool(num_workers) as pool:
198
+ for idx, w_idx, score in pool.imap_unordered(
199
+ compute_ap_from_tuple, input_tuples, chunksize=chunksize):
200
+ ap_scores[idx, w_idx] = score
201
+ else:
202
+ for input_tuple in input_tuples:
203
+ idx, w_idx, score = compute_ap_from_tuple(input_tuple)
204
+ ap_scores[idx, w_idx] = score
205
+
206
+ # it's the same if we first average across different annotations, then average across queries
207
+ # since all queries have the same #annotations.
208
+ mean_ap = float(f"{100 * np.mean(ap_scores):.2f}")
209
+ return mean_ap
210
+
211
+
212
+ def compute_ap_from_tuple(input_tuple):
213
+ idx, w_idx, y_true, y_predict = input_tuple
214
+ if len(y_true) < len(y_predict):
215
+ # print(f"len(y_true) < len(y_predict) {len(y_true), len(y_predict)}")
216
+ y_predict = y_predict[:len(y_true)]
217
+ elif len(y_true) > len(y_predict):
218
+ # print(f"len(y_true) > len(y_predict) {len(y_true), len(y_predict)}")
219
+ _y_predict = np.zeros(len(y_true))
220
+ _y_predict[:len(y_predict)] = y_predict
221
+ y_predict = _y_predict
222
+
223
+ score = get_ap(y_true, y_predict)
224
+ return idx, w_idx, score
225
+
226
+
227
+ def mk_gt_scores(gt_data, clip_length=2):
228
+ """Expand one ground-truth dict into a (#clips_in_video, 3) saliency score array."""
229
+ num_clips = int(gt_data["duration"] / clip_length)
230
+ saliency_scores_full_video = np.zeros((num_clips, 3))
231
+ relevant_clip_ids = np.array(gt_data["relevant_clip_ids"]) # (#relevant_clip_ids, )
232
+ saliency_scores_relevant_clips = np.array(gt_data["saliency_scores"]) # (#relevant_clip_ids, 3)
233
+ saliency_scores_full_video[relevant_clip_ids] = saliency_scores_relevant_clips
234
+ return saliency_scores_full_video # (#clips_in_video, 3) the scores are in range [0, 4]
235
+
236
+
237
+ def eval_highlight(submission, ground_truth, verbose=True):
238
+ """
239
+ Args:
240
+ submission:
241
+ ground_truth:
242
+ verbose:
243
+ """
244
+ qid2preds = {d["qid"]: d for d in submission}
245
+ qid2gt_scores_full_range = {d["qid"]: mk_gt_scores(d) for d in ground_truth} # scores in range [0, 4]
246
+ # gt_saliency_score_min: int, in [0, 1, 2, 3, 4]. The minimum score for a positive clip.
247
+ gt_saliency_score_min_list = [2, 3, 4]
248
+ saliency_score_names = ["Fair", "Good", "VeryGood"]
249
+ highlight_det_metrics = {}
250
+ for gt_saliency_score_min, score_name in zip(gt_saliency_score_min_list, saliency_score_names):
251
+ start_time = time.time()
252
+ qid2gt_scores_binary = {
253
+ k: (v >= gt_saliency_score_min).astype(float)
254
+ for k, v in qid2gt_scores_full_range.items()} # scores in [0, 1]
255
+ hit_at_one = compute_hl_hit1(qid2preds, qid2gt_scores_binary)
256
+ mean_ap = compute_hl_ap(qid2preds, qid2gt_scores_binary)
257
+ highlight_det_metrics[f"HL-min-{score_name}"] = {"HL-mAP": mean_ap, "HL-Hit1": hit_at_one}
258
+ if verbose:
259
+ print(f"Calculating highlight scores with min score {gt_saliency_score_min} ({score_name})")
260
+ print(f"Time cost {time.time() - start_time:.2f} seconds")
261
+ return highlight_det_metrics
262
+
263
+
264
+ def eval_submission(submission, ground_truth, verbose=True, match_number=False, hl=False):
265
+ """
266
+ Args:
267
+ submission: list(dict), each dict is {
268
+ qid: str,
269
+ query: str,
270
+ vid: str,
271
+ pred_relevant_windows: list([st, ed]),
272
+ pred_saliency_scores: list(float), len == #clips in video.
273
+ i.e., each clip in the video will have a saliency score.
274
+ }
275
+ ground_truth: list(dict), each dict is {
276
+ "qid": 7803,
277
+ "query": "Man in gray top walks from outside to inside.",
278
+ "duration": 150,
279
+ "vid": "RoripwjYFp8_360.0_510.0",
280
+ "relevant_clip_ids": [13, 14, 15, 16, 17]
281
+ "saliency_scores": [[4, 4, 2], [3, 4, 2], [2, 2, 3], [2, 2, 2], [0, 1, 3]]
282
+ each sublist corresponds to one clip in relevant_clip_ids.
283
+ The 3 elements in the sublist are scores from 3 different workers. The
284
+ scores are in [0, 1, 2, 3, 4], meaning [Very Bad, ..., Good, Very Good]
285
+ }
286
+ verbose:
287
+ match_number:
288
+
289
+ Returns:
290
+
291
+ """
292
+ pred_qids = set([e["qid"] for e in submission])
293
+ gt_qids = set([e["qid"] for e in ground_truth])
294
+ # import pdb; pdb.set_trace()
295
+ if match_number:
296
+ assert pred_qids == gt_qids, \
297
+ f"qids in ground_truth and submission must match. " \
298
+ f"use `match_number=False` if you wish to disable this check"
299
+ else: # only leave the items that exists in both submission and ground_truth
300
+ shared_qids = pred_qids.intersection(gt_qids)
301
+ submission = [e for e in submission if e["qid"] in shared_qids]
302
+ ground_truth = [e for e in ground_truth if e["qid"] in shared_qids]
303
+
304
+ eval_metrics = {}
305
+ eval_metrics_brief = OrderedDict()
306
+ if "pred_relevant_windows" in submission[0]:
307
+ moment_ret_scores = eval_moment_retrieval(submission, ground_truth, verbose=verbose)
308
+ eval_metrics.update(moment_ret_scores)
309
+ moment_ret_scores_brief = {
310
+ "MR-full-mAP": moment_ret_scores["full"]["MR-mAP"]["average"],
311
+ "[email protected]": moment_ret_scores["full"]["MR-mAP"]["0.5"],
312
+ "[email protected]": moment_ret_scores["full"]["MR-mAP"]["0.75"],
313
+ "MR-short-mAP": moment_ret_scores["short"]["MR-mAP"]["average"],
314
+ "MR-middle-mAP": moment_ret_scores["middle"]["MR-mAP"]["average"],
315
+ "MR-long-mAP": moment_ret_scores["long"]["MR-mAP"]["average"],
316
+ "MR-full-mIoU": moment_ret_scores["full"]["MR-mIoU"],
317
+ "[email protected]": moment_ret_scores["full"]["MR-R1"]["0.3"],
318
+ "[email protected]": moment_ret_scores["full"]["MR-R1"]["0.5"],
319
+ "[email protected]": moment_ret_scores["full"]["MR-R1"]["0.7"],
320
+ }
321
+ eval_metrics_brief.update(
322
+ sorted([(k, v) for k, v in moment_ret_scores_brief.items()], key=lambda x: x[0]))
323
+
324
+ if "pred_saliency_scores" in submission[0] and hl:
325
+ highlight_det_scores = eval_highlight(
326
+ submission, ground_truth, verbose=verbose)
327
+ eval_metrics.update(highlight_det_scores)
328
+ highlight_det_scores_brief = dict([
329
+ (f"{k}-{sub_k.split('-')[1]}", v[sub_k])
330
+ for k, v in highlight_det_scores.items() for sub_k in v])
331
+ eval_metrics_brief.update(highlight_det_scores_brief)
332
+
333
+ # sort by keys
334
+ final_eval_metrics = OrderedDict()
335
+ final_eval_metrics["brief"] = eval_metrics_brief
336
+ final_eval_metrics.update(sorted([(k, v) for k, v in eval_metrics.items()], key=lambda x: x[0]))
337
+ return final_eval_metrics
338
+
339
+
340
+ def eval_main():
341
+ import argparse
342
+ parser = argparse.ArgumentParser(description="Moments and Highlights Evaluation Script")
343
+ parser.add_argument("--submission_path", type=str, help="path to generated prediction file")
344
+ parser.add_argument("--gt_path", type=str, help="path to GT file")
345
+ parser.add_argument("--save_path", type=str, help="path to save the results")
346
+ parser.add_argument("--not_verbose", action="store_true")
347
+ args = parser.parse_args()
348
+
349
+ verbose = not args.not_verbose
350
+ submission = load_jsonl(args.submission_path)
351
+ gt = load_jsonl(args.gt_path)
352
+ results = eval_submission(submission, gt, verbose=verbose)
353
+ if verbose:
354
+ print(json.dumps(results, indent=4))
355
+
356
+ with open(args.save_path, "w") as f:
357
+ f.write(json.dumps(results, indent=4))
358
+
359
+
360
+ if __name__ == '__main__':
361
+ eval_main()
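Besides the `eval_main` command-line entry point, `eval_submission` can be called directly on in-memory lists. The sketch below is illustrative only: the toy submission/ground-truth dicts are made up, and the import path assumes the cgdetr root is on `PYTHONPATH`.

```python
# Illustrative usage sketch for eval_submission(); toy data, assumed import path.
from standalone_eval.eval import eval_submission

if __name__ == "__main__":  # the evaluation uses multiprocessing internally, so guard the entry point
    submission = [{
        "qid": 1,
        "vid": "toy_video_0.0_150.0",
        "pred_relevant_windows": [[10, 30, 0.95], [40, 60, 0.40]],  # [start_sec, end_sec, confidence]
    }]
    ground_truth = [{
        "qid": 1,
        "vid": "toy_video_0.0_150.0",
        "duration": 150,
        "relevant_windows": [[12, 28]],
    }]
    metrics = eval_submission(submission, ground_truth, verbose=False)
    print(metrics["brief"]["[email protected]"])  # Recall@1 at IoU 0.5, in percent
```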
third_party/cgdetr/standalone_eval/eval_sample.sh ADDED
@@ -0,0 +1,10 @@
1
+ #!/usr/bin/env bash
2
+ # Usage: bash standalone_eval/eval_sample.sh
3
+ submission_path=standalone_eval/sample_val_preds.jsonl
4
+ gt_path=data/highlight_val_release.jsonl
5
+ save_path=standalone_eval/sample_val_preds_metrics.json
6
+
7
+ PYTHONPATH=$PYTHONPATH:. python standalone_eval/eval.py \
8
+ --submission_path ${submission_path} \
9
+ --gt_path ${gt_path} \
10
+ --save_path ${save_path}
third_party/cgdetr/standalone_eval/utils.py ADDED
@@ -0,0 +1,209 @@
+ """
+ Copied from MMAction2
+ https://github.com/open-mmlab/mmaction2/blob/master/mmaction/core/evaluation/eval_detection.py
+ """
+ import json
+ import numpy as np
+ from sklearn.metrics import precision_recall_curve
+
+
+ def load_jsonl(filename):
+     with open(filename, "r") as f:
+         return [json.loads(l.strip("\n")) for l in f.readlines()]
+
+
+ def compute_temporal_iou_batch_paired(pred_windows, gt_windows):
+     """ compute intersection-over-union along temporal axis for each pair of windows in pred_windows and gt_windows.
+     Args:
+         pred_windows: np.ndarray, (N, 2), [st (float), ed (float)] * N
+         gt_windows: np.ndarray, (N, 2), [st (float), ed (float)] * N
+     Returns:
+         iou (float): np.ndarray, (N, )
+
+     References:
+         for np.divide with zeros, see https://stackoverflow.com/a/37977222
+     """
+     intersection = np.maximum(
+         0, np.minimum(pred_windows[:, 1], gt_windows[:, 1]) - np.maximum(pred_windows[:, 0], gt_windows[:, 0])
+     )
+     union = np.maximum(pred_windows[:, 1], gt_windows[:, 1]) \
+         - np.minimum(pred_windows[:, 0], gt_windows[:, 0])  # not the correct union though
+     return np.divide(intersection, union, out=np.zeros_like(intersection), where=union != 0)
+
+
+ def compute_temporal_iou_batch_cross(spans1, spans2):
+     """
+     Args:
+         spans1: (N, 2) np.ndarray, each row defines a span [st, ed]
+         spans2: (M, 2) np.ndarray, ...
+
+     Returns:
+         iou: (N, M) np.ndarray
+         union: (N, M) np.ndarray
+     >>> spans1 = np.array([[0, 0.2, 0.9], [0.5, 1.0, 0.2]])
+     >>> spans2 = np.array([[0, 0.3], [0., 1.0]])
+     >>> compute_temporal_iou_batch_cross(spans1, spans2)
+     (tensor([[0.6667, 0.2000],
+              [0.0000, 0.5000]]),
+      tensor([[0.3000, 1.0000],
+              [0.8000, 1.0000]]))
+     """
+     areas1 = spans1[:, 1] - spans1[:, 0]  # (N, )
+     areas2 = spans2[:, 1] - spans2[:, 0]  # (M, )
+
+     left = np.maximum(spans1[:, None, 0], spans2[None, :, 0])  # (N, M)
+     right = np.minimum(spans1[:, None, 1], spans2[None, :, 1])  # (N, M)
+
+     inter = np.clip(right - left, 0, None)  # (N, M)
+     union = areas1[:, None] + areas2[None, :] - inter  # (N, M)
+
+     iou = inter / union
+     return iou, union
+
+
+ def interpolated_precision_recall(precision, recall):
+     """Interpolated AP - VOCdevkit from VOC 2011.
+
+     Args:
+         precision (np.ndarray): The precision of different thresholds.
+         recall (np.ndarray): The recall of different thresholds.
+
+     Returns:
+         float: Average precision score.
+     """
+     mprecision = np.hstack([[0], precision, [0]])
+     mrecall = np.hstack([[0], recall, [1]])
+     for i in range(len(mprecision) - 1)[::-1]:
+         mprecision[i] = max(mprecision[i], mprecision[i + 1])
+     idx = np.where(mrecall[1::] != mrecall[0:-1])[0] + 1
+     ap = np.sum((mrecall[idx] - mrecall[idx - 1]) * mprecision[idx])
+     return ap
+
+
+ def compute_average_precision_detection(ground_truth,
+                                         prediction,
+                                         tiou_thresholds=np.linspace(
+                                             0.5, 0.95, 10)):
+     """Compute average precision (detection task) between ground truth and
+     predictions data frames. If multiple predictions occur for the same
+     predicted segment, only the one with the highest score is matched as true
+     positive. This code is greatly inspired by Pascal VOC devkit.
+
+     Args:
+         ground_truth (list[dict]): List containing the ground truth instances
+             (dictionaries). Required keys are 'video-id', 't-start' and
+             't-end'.
+         prediction (list[dict]): List containing the prediction instances
+             (dictionaries). Required keys are: 'video-id', 't-start', 't-end'
+             and 'score'.
+         tiou_thresholds (np.ndarray): A 1darray indicates the temporal
+             intersection over union threshold, which is optional.
+             Default: ``np.linspace(0.5, 0.95, 10)``.
+
+     Returns:
+         Float: ap, Average precision score.
+     """
+     num_thresholds = len(tiou_thresholds)
+     num_gts = len(ground_truth)
+     num_preds = len(prediction)
+     ap = np.zeros(num_thresholds)
+     if len(prediction) == 0:
+         return ap
+
+     num_positive = float(num_gts)
+     lock_gt = np.ones((num_thresholds, num_gts)) * -1
+     # Sort predictions by decreasing score order.
+     prediction.sort(key=lambda x: -x['score'])
+     # Initialize true positive and false positive vectors.
+     tp = np.zeros((num_thresholds, num_preds))
+     fp = np.zeros((num_thresholds, num_preds))
+
+     # Adaptation to query faster
+     ground_truth_by_videoid = {}
+     for i, item in enumerate(ground_truth):
+         item['index'] = i
+         ground_truth_by_videoid.setdefault(item['video-id'], []).append(item)
+
+     # Assigning true positives to ground truth instances.
+     for idx, pred in enumerate(prediction):
+         if pred['video-id'] in ground_truth_by_videoid:
+             gts = ground_truth_by_videoid[pred['video-id']]
+         else:
+             fp[:, idx] = 1
+             continue
+
+         _pred = np.array([[pred['t-start'], pred['t-end']], ])
+         _gt = np.array([[gt['t-start'], gt['t-end']] for gt in gts])
+         tiou_arr = compute_temporal_iou_batch_cross(_pred, _gt)[0]
+
+         tiou_arr = tiou_arr.reshape(-1)
+         # We would like to retrieve the predictions with highest tiou score.
+         tiou_sorted_idx = tiou_arr.argsort()[::-1]
+         for t_idx, tiou_threshold in enumerate(tiou_thresholds):
+             for j_idx in tiou_sorted_idx:
+                 if tiou_arr[j_idx] < tiou_threshold:
+                     fp[t_idx, idx] = 1
+                     break
+                 if lock_gt[t_idx, gts[j_idx]['index']] >= 0:
+                     continue
+                 # Assign as true positive after the filters above.
+                 tp[t_idx, idx] = 1
+                 lock_gt[t_idx, gts[j_idx]['index']] = idx
+                 break
+
+             if fp[t_idx, idx] == 0 and tp[t_idx, idx] == 0:
+                 fp[t_idx, idx] = 1
+
+     tp_cumsum = np.cumsum(tp, axis=1).astype(float)
+     fp_cumsum = np.cumsum(fp, axis=1).astype(float)
+     recall_cumsum = tp_cumsum / num_positive
+
+     precision_cumsum = tp_cumsum / (tp_cumsum + fp_cumsum)
+
+     for t_idx in range(len(tiou_thresholds)):
+         ap[t_idx] = interpolated_precision_recall(precision_cumsum[t_idx, :],
+                                                   recall_cumsum[t_idx, :])
+     return ap
+
+
+ def get_ap(y_true, y_predict, interpolate=True, point_11=False):
+     """
+     Average precision in different formats: (non-) interpolated and/or 11-point approximated
+     point_11=True and interpolate=True corresponds to the 11-point interpolated AP used in
+     the PASCAL VOC challenge up to the 2008 edition and has been verified against the vlfeat implementation
+     The exact average precision (interpolate=False, point_11=False) corresponds to the one of vl_feat
+
+     :param y_true: list/ numpy vector of true labels in {0,1} for each element
+     :param y_predict: predicted score for each element
+     :param interpolate: Use interpolation?
+     :param point_11: Use 11-point approximation to average precision?
+     :return: average precision
+
+     ref: https://github.com/gyglim/video2gif_dataset/blob/master/v2g_evaluation/__init__.py
+
+     """
+     # Check inputs
+     assert len(y_true) == len(y_predict), "Prediction and ground truth need to be of the same length"
+     if len(set(y_true)) == 1:
+         if y_true[0] == 0:
+             return 0  # True labels are all zeros
+             # raise ValueError('True labels cannot all be zero')
+         else:
+             return 1
+     else:
+         assert sorted(set(y_true)) == [0, 1], "Ground truth can only contain elements {0,1}"
+
+     # Compute precision and recall
+     precision, recall, _ = precision_recall_curve(y_true, y_predict)
+     recall = recall.astype(np.float32)
+
+     if interpolate:  # Compute the interpolated precision
+         for i in range(1, len(precision)):
+             precision[i] = max(precision[i - 1], precision[i])
+
+     if point_11:  # Compute the 11-point approximated AP
+         precision_11 = [precision[np.where(recall >= t)[0][-1]] for t in np.arange(0, 1.01, 0.1)]
+         return np.mean(precision_11)
+     else:  # Compute the AP using precision at every additionally recalled sample
+         indices = np.where(np.diff(recall))
+         return np.mean(precision[indices])
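A quick, illustrative check of the temporal-IoU and detection-AP helpers above; the spans and scores are made up for illustration, and the import path assumes the same PYTHONPATH setup used by eval_sample.sh.

# Toy check of the helpers above (toy values; import path assumes repo root on PYTHONPATH).
import numpy as np
from standalone_eval.utils import (
    compute_temporal_iou_batch_cross, compute_average_precision_detection)

pred_spans = np.array([[0.0, 10.0], [5.0, 15.0]])   # (N, 2) [st, ed] in seconds
gt_spans = np.array([[0.0, 12.0]])                  # (M, 2)
iou, union = compute_temporal_iou_batch_cross(pred_spans, gt_spans)
print(iou)          # (2, 1) pairwise IoU matrix

gt = [{"video-id": "v0", "t-start": 0.0, "t-end": 12.0}]
preds = [{"video-id": "v0", "t-start": 0.0, "t-end": 10.0, "score": 0.9},
         {"video-id": "v0", "t-start": 5.0, "t-end": 15.0, "score": 0.4}]
ap_per_thresh = compute_average_precision_detection(
    gt, preds, tiou_thresholds=np.array([0.5, 0.7]))
print(ap_per_thresh)  # one AP value per tIoU threshold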
third_party/cgdetr/utils/basic_utils.py ADDED
@@ -0,0 +1,221 @@
+ import os
+ import json
+ import zipfile
+ import numpy as np
+ import pickle
+ from collections import OrderedDict, Counter
+ import pandas as pd
+
+
+ def load_pickle(filename):
+     with open(filename, "rb") as f:
+         return pickle.load(f)
+
+
+ def save_pickle(data, filename):
+     with open(filename, "wb") as f:
+         pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
+
+
+ def load_json(filename):
+     with open(filename, "r") as f:
+         return json.load(f)
+
+
+ def save_json(data, filename, save_pretty=False, sort_keys=False):
+     with open(filename, "w") as f:
+         if save_pretty:
+             f.write(json.dumps(data, indent=4, sort_keys=sort_keys))
+         else:
+             json.dump(data, f)
+
+
+ def load_jsonl(filename):
+     with open(filename, "r") as f:
+         return [json.loads(l.strip("\n")) for l in f.readlines()]
+
+
+ def save_jsonl(data, filename):
+     """data is a list"""
+     with open(filename, "w") as f:
+         f.write("\n".join([json.dumps(e) for e in data]))
+
+
+ def save_lines(list_of_str, filepath):
+     with open(filepath, "w") as f:
+         f.write("\n".join(list_of_str))
+
+
+ def read_lines(filepath):
+     with open(filepath, "r") as f:
+         return [e.strip("\n") for e in f.readlines()]
+
+
+ def mkdirp(p):
+     if not os.path.exists(p):
+         os.makedirs(p)
+
+
+ def flat_list_of_lists(l):
+     """flatten a list of lists [[1,2], [3,4]] to [1,2,3,4]"""
+     return [item for sublist in l for item in sublist]
+
+
+ def convert_to_seconds(hms_time):
+     """ convert '00:01:12' to 72 seconds.
+     :hms_time (str): time as a colon-separated string, e.g. '00:01:12'
+     :return (int): time in seconds, e.g. 72
+     """
+     times = [float(t) for t in hms_time.split(":")]
+     return times[0] * 3600 + times[1] * 60 + times[2]
+
+
+ def get_video_name_from_url(url):
+     return url.split("/")[-1][:-4]
+
+
+ def merge_dicts(list_dicts):
+     merged_dict = list_dicts[0].copy()
+     for i in range(1, len(list_dicts)):
+         merged_dict.update(list_dicts[i])
+     return merged_dict
+
+
+ def l2_normalize_np_array(np_array, eps=1e-5):
+     """np_array: np.ndarray, (*, D), where the last dim will be normalized"""
+     return np_array / (np.linalg.norm(np_array, axis=-1, keepdims=True) + eps)
+
+
+ def make_zipfile(src_dir, save_path, enclosing_dir="", exclude_dirs=None, exclude_extensions=None,
+                  exclude_dirs_substring=None):
+     """make a zip file of root_dir, save it to save_path.
+     exclude_paths will be excluded if it is a subdir of root_dir.
+     An enclosing_dir is added if specified.
+     """
+     abs_src = os.path.abspath(src_dir)
+     with zipfile.ZipFile(save_path, "w") as zf:
+         for dirname, subdirs, files in os.walk(src_dir):
+             if exclude_dirs is not None:
+                 for e_p in exclude_dirs:
+                     if e_p in subdirs:
+                         subdirs.remove(e_p)
+             if exclude_dirs_substring is not None:
+                 to_rm = []
+                 for d in subdirs:
+                     if exclude_dirs_substring in d:
+                         to_rm.append(d)
+                 for e in to_rm:
+                     subdirs.remove(e)
+             arcname = os.path.join(enclosing_dir, dirname[len(abs_src) + 1:])
+             zf.write(dirname, arcname)
+             for filename in files:
+                 if exclude_extensions is not None:
+                     if os.path.splitext(filename)[1] in exclude_extensions:
+                         continue  # do not zip it
+                 absname = os.path.join(dirname, filename)
+                 arcname = os.path.join(enclosing_dir, absname[len(abs_src) + 1:])
+                 zf.write(absname, arcname)
+
+
+ class AverageMeter(object):
+     """Computes and stores the average and current/max/min value"""
+     def __init__(self):
+         self.val = 0
+         self.avg = 0
+         self.sum = 0
+         self.count = 0
+         self.max = -1e10
+         self.min = 1e10
+         self.reset()
+
+     def reset(self):
+         self.val = 0
+         self.avg = 0
+         self.sum = 0
+         self.count = 0
+         self.max = -1e10
+         self.min = 1e10
+
+     def update(self, val, n=1):
+         self.max = max(val, self.max)
+         self.min = min(val, self.min)
+         self.val = val
+         self.sum += val * n
+         self.count += n
+         self.avg = self.sum / self.count
+
+
+ def dissect_by_lengths(np_array, lengths, dim=0, assert_equal=True):
+     """Dissect an array (N, D) into a list of sub-arrays,
+     np_array.shape[0] == sum(lengths). Output is a list of nd arrays, singleton dimension is kept"""
+     if assert_equal:
+         assert len(np_array) == sum(lengths)
+     length_indices = [0, ]
+     for i in range(len(lengths)):
+         length_indices.append(length_indices[i] + lengths[i])
+     if dim == 0:
+         array_list = [np_array[length_indices[i]:length_indices[i+1]] for i in range(len(lengths))]
+     elif dim == 1:
+         array_list = [np_array[:, length_indices[i]:length_indices[i + 1]] for i in range(len(lengths))]
+     elif dim == 2:
+         array_list = [np_array[:, :, length_indices[i]:length_indices[i + 1]] for i in range(len(lengths))]
+     else:
+         raise NotImplementedError
+     return array_list
+
+
+ def get_ratio_from_counter(counter_obj, threshold=200):
+     keys = counter_obj.keys()
+     values = counter_obj.values()
+     filtered_values = [counter_obj[k] for k in keys if k > threshold]
+     return float(sum(filtered_values)) / sum(values)
+
+
+ def get_counter_dist(counter_object, sort_type="none"):
+     _sum = sum(counter_object.values())
+     dist = {k: float(f"{100 * v / _sum:.2f}") for k, v in counter_object.items()}
+     if sort_type == "value":
+         dist = OrderedDict(sorted(dist.items(), reverse=True))
+     return dist
+
+
+ def get_show_name(vid_name):
+     """
+     get tvshow name from vid_name
+     :param vid_name: video clip name
+     :return: tvshow name
+     """
+     show_list = ["friends", "met", "castle", "house", "grey"]
+     vid_name_prefix = vid_name.split("_")[0]
+     show_name = vid_name_prefix if vid_name_prefix in show_list else "bbt"
+     return show_name
+
+
+ def get_abspaths_by_ext(dir_path, ext=(".jpg",)):
+     """Get absolute paths to files in dir_path with extensions specified by ext.
+     Note this function does work recursively.
+     """
+     if isinstance(ext, list):
+         ext = tuple(ext)
+     if isinstance(ext, str):
+         ext = tuple([ext, ])
+     filepaths = [os.path.join(root, name)
+                  for root, dirs, files in os.walk(dir_path)
+                  for name in files
+                  if name.endswith(tuple(ext))]
+     return filepaths
+
+
+ def get_basename_no_ext(path):
+     """ '/data/movienet/240p_keyframe_feats/tt7672188.npz' --> 'tt7672188' """
+     return os.path.splitext(os.path.split(path)[1])[0]
+
+
+ def dict_to_markdown(d, max_str_len=120):
+     # convert list into its str representation
+     d = {k: v.__repr__() if isinstance(v, list) else v for k, v in d.items()}
+     # truncate string that is longer than max_str_len
+     if max_str_len is not None:
+         d = {k: v[-max_str_len:] if isinstance(v, str) else v for k, v in d.items()}
+     return pd.DataFrame(d, index=[0]).transpose().to_markdown()
+
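A brief, illustrative sketch of a couple of these helpers; the values are arbitrary, and the import path assumes third_party/cgdetr is the working directory (or is on the path).

# Illustrative use of AverageMeter and dissect_by_lengths (arbitrary toy values).
import numpy as np
from utils.basic_utils import AverageMeter, dissect_by_lengths

meter = AverageMeter()
for loss in [0.9, 0.7, 0.5]:
    meter.update(loss, n=1)
print(meter.avg, meter.min, meter.max)   # running average, min and max

feats = np.random.rand(7, 4)             # e.g. 7 clip features of dim 4
chunks = dissect_by_lengths(feats, lengths=[3, 4], dim=0)
print([c.shape for c in chunks])         # [(3, 4), (4, 4)]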