ynhe
commited on
Commit
·
16dc4f2
1
Parent(s):
371f0d2
init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- added_tokens.json +9 -0
- config.json +60 -0
- model-00001-of-00004.safetensors +3 -0
- model-00002-of-00004.safetensors +3 -0
- model-00003-of-00004.safetensors +3 -0
- model-00004-of-00004.safetensors +3 -0
- model.safetensors.index.json +0 -0
- model_config.py +24 -0
- modeling_base.py +387 -0
- modeling_qformer.py +1264 -0
- modeling_special_token.py +27 -0
- modeling_videochate.py +681 -0
- modeling_vit.py +487 -0
- special_tokens_map.json +24 -0
- third_party/__init__.py +2 -0
- third_party/cgdetr/cg_detr/__init__.py +0 -0
- third_party/cgdetr/cg_detr/__pycache__/__init__.cpython-310.pyc +0 -0
- third_party/cgdetr/cg_detr/__pycache__/attention.cpython-310.pyc +0 -0
- third_party/cgdetr/cg_detr/__pycache__/crossattention.cpython-310.pyc +0 -0
- third_party/cgdetr/cg_detr/__pycache__/matcher.cpython-310.pyc +0 -0
- third_party/cgdetr/cg_detr/__pycache__/misc.cpython-310.pyc +0 -0
- third_party/cgdetr/cg_detr/__pycache__/model.cpython-310.pyc +0 -0
- third_party/cgdetr/cg_detr/__pycache__/position_encoding.cpython-310.pyc +0 -0
- third_party/cgdetr/cg_detr/__pycache__/span_utils.cpython-310.pyc +0 -0
- third_party/cgdetr/cg_detr/__pycache__/transformer.cpython-310.pyc +0 -0
- third_party/cgdetr/cg_detr/attention.py +394 -0
- third_party/cgdetr/cg_detr/config.py +261 -0
- third_party/cgdetr/cg_detr/crossattention.py +396 -0
- third_party/cgdetr/cg_detr/inference.py +480 -0
- third_party/cgdetr/cg_detr/matcher.py +109 -0
- third_party/cgdetr/cg_detr/misc.py +21 -0
- third_party/cgdetr/cg_detr/model.py +1178 -0
- third_party/cgdetr/cg_detr/position_encoding.py +116 -0
- third_party/cgdetr/cg_detr/postprocessing_cg_detr.py +95 -0
- third_party/cgdetr/cg_detr/scripts/charades_sta/inference.sh +8 -0
- third_party/cgdetr/cg_detr/scripts/charades_sta/train.sh +95 -0
- third_party/cgdetr/cg_detr/scripts/inference.sh +11 -0
- third_party/cgdetr/cg_detr/scripts/train.sh +76 -0
- third_party/cgdetr/cg_detr/span_utils.py +127 -0
- third_party/cgdetr/cg_detr/start_end_dataset.py +383 -0
- third_party/cgdetr/cg_detr/text_encoder.py +53 -0
- third_party/cgdetr/cg_detr/train.py +283 -0
- third_party/cgdetr/cg_detr/transformer.py +871 -0
- third_party/cgdetr/data/LICENSE +437 -0
- third_party/cgdetr/data/README.md +24 -0
- third_party/cgdetr/standalone_eval/README.md +54 -0
- third_party/cgdetr/standalone_eval/eval.py +361 -0
- third_party/cgdetr/standalone_eval/eval_sample.sh +10 -0
- third_party/cgdetr/standalone_eval/utils.py +209 -0
- third_party/cgdetr/utils/basic_utils.py +221 -0
added_tokens.json
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"<box_begin>": 32000,
|
3 |
+
"<boxes>": 32003,
|
4 |
+
"<temp>": 32002,
|
5 |
+
"<time_begin>": 32001,
|
6 |
+
"<track_begin>": 32004,
|
7 |
+
"<track_box>": 32006,
|
8 |
+
"<tracking>": 32005
|
9 |
+
}
|
config.json
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "",
|
3 |
+
"architectures": [
|
4 |
+
"MultiModalLLM_PT"
|
5 |
+
],
|
6 |
+
"auto_map": {
|
7 |
+
"AutoConfig": "model_config.VideoChatEConfig",
|
8 |
+
"AutoModel": "modeling_videochate.MultiModalLLM_PT"
|
9 |
+
},
|
10 |
+
"model_config": {
|
11 |
+
"bridge": {
|
12 |
+
"extra_num_query_token": 64,
|
13 |
+
"name": "qformer",
|
14 |
+
"num_query_token": 32,
|
15 |
+
"qformer_attention_probs_dropout_prob": 0.1,
|
16 |
+
"qformer_drop_path_rate": 0.2,
|
17 |
+
"qformer_hidden_dropout_prob": 0.1
|
18 |
+
},
|
19 |
+
"freeze_bridge": false,
|
20 |
+
"freeze_llm": false,
|
21 |
+
"freeze_vision_encoder": false,
|
22 |
+
"llm": {
|
23 |
+
"lora_alpha": 32,
|
24 |
+
"lora_dropout": 0.1,
|
25 |
+
"lora_r": 16,
|
26 |
+
"name": "mistral_7b",
|
27 |
+
"pretrained_llm_path": "mistralai/Mistral-7B-Instruct-v0.3",
|
28 |
+
"use_lora": true,
|
29 |
+
"hidden_size": 4096
|
30 |
+
},
|
31 |
+
"loss": {
|
32 |
+
"use_vision_regression_loss": false
|
33 |
+
},
|
34 |
+
"pretrained_paths": {},
|
35 |
+
|
36 |
+
"vision_encoder": {
|
37 |
+
"name":"vit_l14",
|
38 |
+
"img_size":224,
|
39 |
+
"patch_size":16,
|
40 |
+
"d_model":1024,
|
41 |
+
"encoder_embed_dim":1024,
|
42 |
+
"encoder_depth":24,
|
43 |
+
"encoder_num_heads":16,
|
44 |
+
"drop_path_rate": 0.0,
|
45 |
+
"num_frames":16,
|
46 |
+
"tubelet_size":1,
|
47 |
+
"use_checkpoint":false,
|
48 |
+
"checkpoint_num":0,
|
49 |
+
"return_index":-2,
|
50 |
+
"vit_add_ln":true,
|
51 |
+
"pretrained": null
|
52 |
+
}
|
53 |
+
},
|
54 |
+
"torch_dtype": "float32",
|
55 |
+
"transformers_version": "4.38.0",
|
56 |
+
"use_flash_attention": true,
|
57 |
+
"use_cache": true,
|
58 |
+
"build_decoder":true,
|
59 |
+
"hidden_size": 4096
|
60 |
+
}
|
model-00001-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c8640081ef34803134daaa9c4a69693b680bce89b3dcc66ed22df3a26d97a330
|
3 |
+
size 4995778624
|
model-00002-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8907fc1447bbaac12ba0e687cd25a19810d2717fe7e73dad57f70f13532c086c
|
3 |
+
size 4945367960
|
model-00003-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5640bcb11fc22f6a7604176b29807f56a856126ed62a97a9e99585eb313e4f93
|
3 |
+
size 4945392936
|
model-00004-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d0360c7932a26ce651a78f5dc26e7b6735ff0bc019fb8adce231fea4375d4fe8
|
3 |
+
size 1316534788
|
model.safetensors.index.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
model_config.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import re, ast
|
3 |
+
from transformers import AutoConfig, LlamaConfig
|
4 |
+
from transformers.configuration_utils import PretrainedConfig
|
5 |
+
from transformers.utils import logging
|
6 |
+
|
7 |
+
from easydict import EasyDict as MyEasyDict
|
8 |
+
from importlib import import_module
|
9 |
+
import os.path as osp
|
10 |
+
import argparse
|
11 |
+
import json
|
12 |
+
from copy import deepcopy
|
13 |
+
import sys
|
14 |
+
|
15 |
+
|
16 |
+
class VideoChatEConfig(PretrainedConfig):
|
17 |
+
model_type = 'VideoChatE'
|
18 |
+
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
model_config=None,
|
22 |
+
**kwargs):
|
23 |
+
super().__init__(**kwargs)
|
24 |
+
self.model_config = MyEasyDict(model_config)
|
modeling_base.py
ADDED
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import logging
|
3 |
+
import torch
|
4 |
+
import torch.utils.checkpoint
|
5 |
+
from torch import nn
|
6 |
+
from torch.nn import MSELoss
|
7 |
+
from transformers.modeling_outputs import (
|
8 |
+
CausalLMOutputWithPast,
|
9 |
+
)
|
10 |
+
from typing import List, Optional, Tuple, Union
|
11 |
+
from transformers import LlamaForCausalLM
|
12 |
+
|
13 |
+
from torch.cuda.amp import autocast as autocast
|
14 |
+
|
15 |
+
from .modeling_vit import build_vit
|
16 |
+
from .modeling_qformer import build_qformer
|
17 |
+
from .model_config import VideoChatEConfig
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
from transformers import LlamaTokenizer,AutoTokenizer,AutoModel,AutoModelForCausalLM,AutoProcessor
|
21 |
+
from transformers import AutoConfig, PreTrainedModel
|
22 |
+
|
23 |
+
import os
|
24 |
+
import sys
|
25 |
+
|
26 |
+
|
27 |
+
try:
|
28 |
+
from third_party.sam2.build_sam import build_sam2_video_predictor
|
29 |
+
from third_party.cgdetr.cg_detr.model import build_cgdetr_model
|
30 |
+
except:
|
31 |
+
print("can not import sam2 and cg-detr, install them first.")
|
32 |
+
|
33 |
+
DEFAULT_IMG_TOKEN = "[IMG]"
|
34 |
+
DEFAULT_IMG_END_TOKEN = "[/IMG]"
|
35 |
+
|
36 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
37 |
+
DEFAULT_VIDEO_TOKEN = "[VIDEO]"
|
38 |
+
|
39 |
+
IMG_TOKEN = "[<IMG_PLH>]"
|
40 |
+
VID_TOKEN = "[<VID_PLH>]"
|
41 |
+
|
42 |
+
BOX_START = '<box_begin>'
|
43 |
+
# BOX_END = '<box_end>'
|
44 |
+
ATBOXES_PLACEHOLDER = '<box_begin><boxes>'
|
45 |
+
# ATBOXES_PLACEHOLDER = '<box_begin>'
|
46 |
+
BOXES_PLACEHOLDER = '<boxes>'
|
47 |
+
EXPR_PLACEHOLDER = '<expr>'
|
48 |
+
QUESTION_PLACEHOLDER = '<question>'
|
49 |
+
TIME_START = '<time_begin>'
|
50 |
+
# TIME_END = '<time_end>'
|
51 |
+
TIME_PLACEHOLDER = '<temp>'
|
52 |
+
ATTEMP_PLACEHOLDER = TIME_START + TIME_PLACEHOLDER
|
53 |
+
# ATTEMP_PLACEHOLDER = TIME_START
|
54 |
+
TRACK_START='<track_begin>'
|
55 |
+
TRACK_PLACEHOLDER = '<tracking>'
|
56 |
+
TRACK_START_BOX = '<track_box>'
|
57 |
+
ATTRACK_PLACEHOLDER = TRACK_START + TRACK_PLACEHOLDER
|
58 |
+
need_template_list = ['REC', 'flickr', 'tracking', 'tracking2', 'tracking3', 'tracking4']
|
59 |
+
|
60 |
+
load_image_list = ['image', 'REC', 'flickr']
|
61 |
+
load_video_list = ['video', 'TVG', 'tracking', 'tracking2','tracking3', 'tracking4', 'TVG+HL']
|
62 |
+
special_tokens = [BOX_START, TIME_START, TIME_PLACEHOLDER, BOXES_PLACEHOLDER, TRACK_START, TRACK_PLACEHOLDER, TRACK_START_BOX]
|
63 |
+
|
64 |
+
def disabled_train(self, mode=True):
|
65 |
+
"""Overwrite model.train with this function to make sure train/eval mode
|
66 |
+
does not change anymore."""
|
67 |
+
return self
|
68 |
+
|
69 |
+
|
70 |
+
def freeze_module(module):
|
71 |
+
for _, param in module.named_parameters():
|
72 |
+
param.requires_grad = False
|
73 |
+
module = module.eval()
|
74 |
+
module.train = disabled_train
|
75 |
+
return module
|
76 |
+
|
77 |
+
|
78 |
+
class LLMConfig(AutoConfig):
|
79 |
+
model_type = "20b"
|
80 |
+
|
81 |
+
|
82 |
+
class BaseMLLM(PreTrainedModel):
|
83 |
+
config_class = VideoChatEConfig
|
84 |
+
def __init__(self, config,_tokenizer=None):
|
85 |
+
# super().__init__(config)
|
86 |
+
self.model_config = config.model_config
|
87 |
+
self.tokenizer = _tokenizer
|
88 |
+
|
89 |
+
config.cg_opt = None
|
90 |
+
config.model_config = None
|
91 |
+
config.model_tokenizer = None
|
92 |
+
super().__init__(config)
|
93 |
+
self.build_vision_encoder()
|
94 |
+
self.build_llm()
|
95 |
+
self.build_bridge()
|
96 |
+
self.build_loss()
|
97 |
+
|
98 |
+
self.load_pretrained_weights()
|
99 |
+
try:
|
100 |
+
if config.build_decoder:
|
101 |
+
self.cg_opt = config.cg_opt
|
102 |
+
self.build_bbox_decoder()
|
103 |
+
self.build_sam()
|
104 |
+
self.build_CGDETR()
|
105 |
+
except:
|
106 |
+
print("please install cgdetr and sam2 first")
|
107 |
+
logger.info(f'Length of tokenizer and resize embedding: {len(self.tokenizer)}')
|
108 |
+
|
109 |
+
|
110 |
+
def build_vision_encoder(self):
|
111 |
+
if 'internvideo2' in self.model_config.vision_encoder.name.lower():
|
112 |
+
encoder_name = self.model_config.vision_encoder.name
|
113 |
+
logger.info(f"Build vision_encoder: {encoder_name}")
|
114 |
+
if encoder_name == 'internvideo2-1B':
|
115 |
+
self.vision_encoder = pretrain_internvideo2_giant_patch14_224_clean(self.model_config)
|
116 |
+
|
117 |
+
else:
|
118 |
+
raise ValueError(f"Not implemented: {encoder_name}")
|
119 |
+
elif 'vit' in self.model_config.vision_encoder.name.lower():
|
120 |
+
self.vision_encoder = build_vit(self.model_config)
|
121 |
+
else:
|
122 |
+
raise NotImplementedError(self.model_config.vision_encoder.name)
|
123 |
+
|
124 |
+
if self.model_config.vision_encoder.vit_add_ln:
|
125 |
+
self.vision_layernorm = nn.LayerNorm(self.model_config.vision_encoder.encoder_embed_dim, eps=1e-12)
|
126 |
+
else:
|
127 |
+
self.vision_layernorm = nn.Identity()
|
128 |
+
|
129 |
+
self.freeze_vision_encoder = self.model_config.get("freeze_vision_encoder", False)
|
130 |
+
|
131 |
+
if self.freeze_vision_encoder:
|
132 |
+
logger.info("freeze vision encoder")
|
133 |
+
freeze_module(self.vision_encoder)
|
134 |
+
freeze_module(self.vision_layernorm)
|
135 |
+
|
136 |
+
def build_CGDETR(self):
|
137 |
+
self.cg_model, self.cg_criterion = build_cgdetr_model()
|
138 |
+
|
139 |
+
def build_bridge(self):
|
140 |
+
# ViT to LM: 1792 -> 6656 NOTE 768 is qformer dim
|
141 |
+
self.project_up = nn.Linear(768, self.lm.config.hidden_size) # whether bias is needed?
|
142 |
+
# LM to ViT: 6656 -> 1792
|
143 |
+
self.project_down = nn.Linear(self.lm.config.hidden_size, 768)
|
144 |
+
|
145 |
+
if 'qformer' in self.model_config.bridge.name.lower():
|
146 |
+
from transformers import BertTokenizer
|
147 |
+
self.qformer_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="left")
|
148 |
+
self.qformer_tokenizer.add_special_tokens({"bos_token": "[DEC]"})
|
149 |
+
self.qformer_tokenizer.padding_side = "left"
|
150 |
+
if self.model_config.bridge.name == 'qformer':
|
151 |
+
self.qformer, self.query_tokens = build_qformer(
|
152 |
+
self.model_config.bridge.num_query_token, self.model_config.vision_encoder.encoder_embed_dim,
|
153 |
+
qformer_hidden_dropout_prob=self.model_config.bridge.qformer_hidden_dropout_prob,
|
154 |
+
qformer_attention_probs_dropout_prob=self.model_config.bridge.qformer_attention_probs_dropout_prob,
|
155 |
+
qformer_drop_path_rate=self.model_config.bridge.qformer_drop_path_rate,
|
156 |
+
)
|
157 |
+
elif self.model_config.bridge.name == 'causal_qformer':
|
158 |
+
self.qformer, self.query_tokens = build_causal_qformer(
|
159 |
+
self.model_config.bridge.num_query_token, self.model_config.vision_encoder.encoder_embed_dim,
|
160 |
+
qformer_hidden_dropout_prob=self.model_config.bridge.qformer_hidden_dropout_prob,
|
161 |
+
qformer_attention_probs_dropout_prob=self.model_config.bridge.qformer_attention_probs_dropout_prob
|
162 |
+
)
|
163 |
+
self.qformer.resize_token_embeddings(len(self.qformer_tokenizer))
|
164 |
+
self.qformer.cls = None
|
165 |
+
self.extra_num_query_token = self.model_config.bridge.extra_num_query_token
|
166 |
+
if self.model_config.bridge.extra_num_query_token > 0:
|
167 |
+
logger.info(f"Add extra {self.model_config.bridge.extra_num_query_token} tokens in QFormer")
|
168 |
+
self.extra_query_tokens = nn.Parameter(
|
169 |
+
torch.zeros(1, self.model_config.bridge.extra_num_query_token, self.query_tokens.shape[-1])
|
170 |
+
)
|
171 |
+
|
172 |
+
self.freeze_bridge = self.model_config.get("freeze_bridge", False)
|
173 |
+
if self.freeze_bridge:
|
174 |
+
logger.info("freeze bridge")
|
175 |
+
freeze_module(self.qformer)
|
176 |
+
self.query_tokens.requires_grad = False
|
177 |
+
|
178 |
+
def build_llm(self):
|
179 |
+
self.lm_name = self.model_config.llm.name
|
180 |
+
if self.model_config.llm.name == "vicuna1.5_7b":
|
181 |
+
self.lm = LlamaForCausalLM.from_pretrained(self.model_config.llm.pretrained_llm_path)
|
182 |
+
self.lm.gradient_checkpointing = self.model_config.llm.get("use_llama_gradient_checkpointing", True)
|
183 |
+
elif self.model_config.llm.name == 'mistral_7b':
|
184 |
+
from transformers import AutoModelForCausalLM
|
185 |
+
|
186 |
+
config = AutoConfig.from_pretrained(
|
187 |
+
self.model_config.llm.pretrained_llm_path,
|
188 |
+
torch_dtype=torch.bfloat16,
|
189 |
+
# attn_implementation="flash_attention_2",
|
190 |
+
)
|
191 |
+
self.lm = AutoModelForCausalLM.from_config(config)
|
192 |
+
elif self.model_config.llm.name == 'internlm_20b':
|
193 |
+
from transformers import AutoModelForCausalLM
|
194 |
+
self.lm = AutoModelForCausalLM.from_pretrained(
|
195 |
+
self.model_config.llm.pretrained_llm_path,
|
196 |
+
torch_dtype=torch.bfloat16,
|
197 |
+
trust_remote_code=True,
|
198 |
+
)
|
199 |
+
self.lm.gradient_checkpointing = True
|
200 |
+
self.lm._set_gradient_checkpointing()
|
201 |
+
else:
|
202 |
+
raise NotImplementedError(self.model_config.llm.name)
|
203 |
+
|
204 |
+
num_new_tokens = len(special_tokens)
|
205 |
+
self.lm.resize_token_embeddings(len(self.tokenizer))
|
206 |
+
|
207 |
+
input_embeddings = self.lm.get_input_embeddings().weight.data
|
208 |
+
output_embeddings = self.lm.get_output_embeddings().weight.data
|
209 |
+
|
210 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
211 |
+
dim=0, keepdim=True)
|
212 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
213 |
+
dim=0, keepdim=True)
|
214 |
+
|
215 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
216 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
217 |
+
|
218 |
+
self.model_config.token_at_ids = self.tokenizer.convert_tokens_to_ids([BOX_START])[0]
|
219 |
+
self.freeze_llm = self.model_config.get("freeze_llm", True)
|
220 |
+
logger.info(f'freeze_llm: {self.freeze_llm}')
|
221 |
+
if self.freeze_llm:
|
222 |
+
logger.info("freeze llm")
|
223 |
+
freeze_module(self.lm)
|
224 |
+
|
225 |
+
if self.model_config.llm.use_lora:
|
226 |
+
self.use_lora = True
|
227 |
+
from peft import get_peft_model, LoraConfig, TaskType
|
228 |
+
logger.info("Use lora")
|
229 |
+
if self.model_config.llm.name == 'internlm_20b':
|
230 |
+
peft_config = LoraConfig(
|
231 |
+
task_type=TaskType.CAUSAL_LM, inference_mode=False,
|
232 |
+
r=self.model_config.llm.lora_r, lora_alpha=self.model_config.llm.lora_alpha, lora_dropout=self.model_config.llm.lora_dropout,
|
233 |
+
target_modules=['wqkv', 'wo', 'w1', 'w2', 'w3', 'output']
|
234 |
+
)
|
235 |
+
else:
|
236 |
+
peft_config = LoraConfig(
|
237 |
+
task_type=TaskType.CAUSAL_LM, inference_mode=False,
|
238 |
+
r=self.model_config.llm.lora_r, lora_alpha=self.model_config.llm.lora_alpha, lora_dropout=self.model_config.llm.lora_dropout,
|
239 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
240 |
+
"gate_proj", "up_proj", "down_proj", "lm_head"]
|
241 |
+
)
|
242 |
+
|
243 |
+
self.lm = get_peft_model(self.lm, peft_config)
|
244 |
+
self.lm.enable_input_require_grads()
|
245 |
+
self.lm.print_trainable_parameters()
|
246 |
+
|
247 |
+
if self.model_config.get("freeze_lora", False):
|
248 |
+
logger.info("freeze lora")
|
249 |
+
freeze_module(self.lm)
|
250 |
+
self.lm.print_trainable_parameters()
|
251 |
+
|
252 |
+
else:
|
253 |
+
self.use_lora = False
|
254 |
+
|
255 |
+
def add_lora(self):
|
256 |
+
if self.model_config.llm.use_lora:
|
257 |
+
self.use_lora = True
|
258 |
+
from peft import get_peft_model, LoraConfig, TaskType
|
259 |
+
logger.info("Use lora")
|
260 |
+
if self.model_config.llm.name == 'internlm_20b':
|
261 |
+
peft_config = LoraConfig(
|
262 |
+
task_type=TaskType.CAUSAL_LM, inference_mode=False,
|
263 |
+
r=self.model_config.llm.lora_r, lora_alpha=self.model_config.llm.lora_alpha, lora_dropout=self.model_config.llm.lora_dropout,
|
264 |
+
target_modules=['wqkv', 'wo', 'w1', 'w2', 'w3', 'output']
|
265 |
+
)
|
266 |
+
else:
|
267 |
+
peft_config = LoraConfig(
|
268 |
+
task_type=TaskType.CAUSAL_LM, inference_mode=False,
|
269 |
+
r=self.model_config.llm.lora_r, lora_alpha=self.model_config.llm.lora_alpha, lora_dropout=self.model_config.llm.lora_dropout,
|
270 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
271 |
+
"gate_proj", "up_proj", "down_proj", "lm_head"]
|
272 |
+
)
|
273 |
+
|
274 |
+
self.lm = get_peft_model(self.lm, peft_config)
|
275 |
+
self.lm.enable_input_require_grads()
|
276 |
+
self.lm.print_trainable_parameters()
|
277 |
+
|
278 |
+
if self.model_config.get("freeze_lora", False):
|
279 |
+
logger.info("freeze lora")
|
280 |
+
freeze_module(self.lm)
|
281 |
+
self.lm.print_trainable_parameters()
|
282 |
+
|
283 |
+
else:
|
284 |
+
self.use_lora = False
|
285 |
+
|
286 |
+
def add_tokens(self):
|
287 |
+
num_new_tokens = len(special_tokens)
|
288 |
+
self.lm.resize_token_embeddings(len(self.tokenizer))
|
289 |
+
|
290 |
+
input_embeddings = self.lm.get_input_embeddings().weight.data
|
291 |
+
output_embeddings = self.lm.get_output_embeddings().weight.data
|
292 |
+
|
293 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
294 |
+
dim=0, keepdim=True)
|
295 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
296 |
+
dim=0, keepdim=True)
|
297 |
+
print(self.lm.get_input_embeddings().weight.data.shape)
|
298 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
299 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
300 |
+
|
301 |
+
self.model_config.token_at_ids = self.tokenizer.convert_tokens_to_ids([BOX_START])[0]
|
302 |
+
|
303 |
+
def build_loss(self):
|
304 |
+
self.use_vision_regression_loss = self.model_config.loss.get("use_vision_regression_loss", False)
|
305 |
+
self.use_bbox_loss = self.model_config.loss.get("add_bbox_loss", False)
|
306 |
+
self.use_mask_loss = self.model_config.loss.get("use_mask_loss", False)
|
307 |
+
self.use_temporal_loss = self.model_config.loss.get('use_temporal_loss', False)
|
308 |
+
if self.use_vision_regression_loss:
|
309 |
+
self.image_loss_fct = MSELoss()
|
310 |
+
|
311 |
+
|
312 |
+
def load_pretrained_weights(self):
|
313 |
+
if self.model_config.pretrained_paths.get('pretrained_vit_qformer_path', None):
|
314 |
+
if 'safetensor' in self.model_config.pretrained_paths.pretrained_vit_qformer_path:
|
315 |
+
from safetensors import safe_open
|
316 |
+
from safetensors.torch import save_file
|
317 |
+
state_dict = {}
|
318 |
+
with safe_open(self.model_config.pretrained_paths.pretrained_vit_qformer_path, framework="pt", device="cpu") as f:
|
319 |
+
for key in f.keys():
|
320 |
+
state_dict[key] = f.get_tensor(key)
|
321 |
+
else:
|
322 |
+
state_dict = torch.load(self.model_config.pretrained_paths.pretrained_vit_qformer_path, map_location="cpu")
|
323 |
+
if "model" in state_dict.keys():
|
324 |
+
state_dict = state_dict["model"]
|
325 |
+
elif "module" in state_dict.keys():
|
326 |
+
state_dict = state_dict["module"] # for deepspeed
|
327 |
+
self.check_temp_emb(state_dict)
|
328 |
+
msg = self.load_state_dict(state_dict, strict=False)
|
329 |
+
print('Loading vit: ', msg)
|
330 |
+
logger.info(f"Load ViT and QFormer from {self.model_config.pretrained_paths.pretrained_vit_qformer_path}: {msg}")
|
331 |
+
|
332 |
+
if self.model_config.pretrained_paths.get('pretrained_videochat2', None):
|
333 |
+
state_dict = torch.load(self.model_config.pretrained_paths.pretrained_videochat2, map_location="cpu")
|
334 |
+
|
335 |
+
new_state_dict = {}
|
336 |
+
for k in state_dict.keys():
|
337 |
+
if 'bert.embeddings' not in k:
|
338 |
+
new_state_dict[k] = state_dict[k]
|
339 |
+
state_dict = new_state_dict
|
340 |
+
# self.check_temp_emb(state_dict)
|
341 |
+
msg = self.load_state_dict(state_dict, strict=False)
|
342 |
+
print('Loading videochat2: ', msg)
|
343 |
+
|
344 |
+
|
345 |
+
def check_temp_emb(self, state_dict):
|
346 |
+
old_num_frames = self.model_config.vision_encoder.get('origin_num_frames', None)
|
347 |
+
new_num_frames = self.model_config.vision_encoder.num_frames
|
348 |
+
if old_num_frames is not None and old_num_frames != new_num_frames:
|
349 |
+
logger.info(f"interpolate_pos_embed_internvideo2 to {new_num_frames} (origin_num_frames={old_num_frames})!!!")
|
350 |
+
a = len(state_dict)
|
351 |
+
interpolate_pos_embed_internvideo2_new(state_dict, self.vision_encoder, orig_t_size=4)
|
352 |
+
assert a == len(state_dict), state_dict.keys()
|
353 |
+
|
354 |
+
def build_bbox_decoder(self):
|
355 |
+
self.loc_encoder = nn.Sequential(
|
356 |
+
nn.Linear(4, self.model_config.llm.hidden_size // 2, dtype=torch.bfloat16),
|
357 |
+
nn.ReLU(),
|
358 |
+
nn.Linear(self.model_config.llm.hidden_size // 2, self.model_config.llm.hidden_size, dtype=torch.bfloat16),
|
359 |
+
)
|
360 |
+
|
361 |
+
self.loc_decoder = nn.Sequential(
|
362 |
+
nn.Linear(self.model_config.llm.hidden_size, self.model_config.llm.hidden_size // 2, dtype=torch.bfloat16),
|
363 |
+
nn.ReLU(),
|
364 |
+
nn.Linear(self.model_config.llm.hidden_size // 2, 4, dtype=torch.bfloat16)
|
365 |
+
)
|
366 |
+
self._initialize_bbox_weights()
|
367 |
+
|
368 |
+
def _initialize_bbox_weights(self):
|
369 |
+
return
|
370 |
+
|
371 |
+
def build_sam(self):
|
372 |
+
sam2_checkpoint = "/cpfs01/user/heyinan/checkpoints/sam2_hiera_large.pt"
|
373 |
+
model_cfg = "sam2_hiera_l.yaml"
|
374 |
+
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=self.lm.device)
|
375 |
+
|
376 |
+
self.sam = predictor
|
377 |
+
freeze_module(self.sam)
|
378 |
+
|
379 |
+
|
380 |
+
@property
|
381 |
+
def dtype(self):
|
382 |
+
return self.lm.dtype
|
383 |
+
|
384 |
+
|
385 |
+
@property
|
386 |
+
def device(self):
|
387 |
+
return self.lm.device
|
modeling_qformer.py
ADDED
@@ -0,0 +1,1264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
* Copyright (c) 2023, salesforce.com, inc.
|
3 |
+
* All rights reserved.
|
4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
* By Junnan Li
|
7 |
+
* Based on huggingface code base
|
8 |
+
* https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
|
9 |
+
"""
|
10 |
+
import logging
|
11 |
+
import math
|
12 |
+
import os
|
13 |
+
import warnings
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Optional, Tuple, Dict, Any
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch import Tensor, device, dtype, nn
|
19 |
+
import torch.utils.checkpoint
|
20 |
+
from torch import nn
|
21 |
+
from torch.nn import CrossEntropyLoss
|
22 |
+
import torch.nn.functional as F
|
23 |
+
|
24 |
+
from timm.models.layers import drop_path
|
25 |
+
from transformers.activations import ACT2FN
|
26 |
+
from transformers.file_utils import (
|
27 |
+
ModelOutput,
|
28 |
+
)
|
29 |
+
from transformers.modeling_outputs import (
|
30 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
31 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
32 |
+
CausalLMOutputWithCrossAttentions,
|
33 |
+
MaskedLMOutput,
|
34 |
+
MultipleChoiceModelOutput,
|
35 |
+
NextSentencePredictorOutput,
|
36 |
+
QuestionAnsweringModelOutput,
|
37 |
+
SequenceClassifierOutput,
|
38 |
+
TokenClassifierOutput,
|
39 |
+
)
|
40 |
+
from transformers.modeling_utils import (
|
41 |
+
PreTrainedModel,
|
42 |
+
apply_chunking_to_forward,
|
43 |
+
find_pruneable_heads_and_indices,
|
44 |
+
prune_linear_layer,
|
45 |
+
)
|
46 |
+
from transformers.models.bert.configuration_bert import BertConfig
|
47 |
+
|
48 |
+
import logging
|
49 |
+
logger = logging.getLogger(__name__)
|
50 |
+
|
51 |
+
|
52 |
+
class BertEmbeddings(nn.Module):
|
53 |
+
"""Construct the embeddings from word and position embeddings."""
|
54 |
+
|
55 |
+
def __init__(self, config):
|
56 |
+
super().__init__()
|
57 |
+
self.word_embeddings = nn.Embedding(
|
58 |
+
config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
|
59 |
+
)
|
60 |
+
self.position_embeddings = nn.Embedding(
|
61 |
+
config.max_position_embeddings, config.hidden_size
|
62 |
+
)
|
63 |
+
|
64 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
65 |
+
# any TensorFlow checkpoint file
|
66 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
67 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
68 |
+
|
69 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
70 |
+
self.register_buffer(
|
71 |
+
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
|
72 |
+
)
|
73 |
+
self.position_embedding_type = getattr(
|
74 |
+
config, "position_embedding_type", "absolute"
|
75 |
+
)
|
76 |
+
|
77 |
+
self.config = config
|
78 |
+
|
79 |
+
def forward(
|
80 |
+
self,
|
81 |
+
input_ids=None,
|
82 |
+
position_ids=None,
|
83 |
+
query_embeds=None,
|
84 |
+
past_key_values_length=0,
|
85 |
+
):
|
86 |
+
if input_ids is not None:
|
87 |
+
seq_length = input_ids.size()[1]
|
88 |
+
else:
|
89 |
+
seq_length = 0
|
90 |
+
|
91 |
+
if position_ids is None:
|
92 |
+
position_ids = self.position_ids[
|
93 |
+
:, past_key_values_length : seq_length + past_key_values_length
|
94 |
+
].clone()
|
95 |
+
|
96 |
+
if input_ids is not None:
|
97 |
+
embeddings = self.word_embeddings(input_ids)
|
98 |
+
if self.position_embedding_type == "absolute":
|
99 |
+
position_embeddings = self.position_embeddings(position_ids)
|
100 |
+
embeddings = embeddings + position_embeddings
|
101 |
+
|
102 |
+
if query_embeds is not None:
|
103 |
+
embeddings = torch.cat((query_embeds, embeddings), dim=1)
|
104 |
+
else:
|
105 |
+
embeddings = query_embeds
|
106 |
+
|
107 |
+
embeddings = self.LayerNorm(embeddings)
|
108 |
+
embeddings = self.dropout(embeddings)
|
109 |
+
return embeddings
|
110 |
+
|
111 |
+
|
112 |
+
class BertSelfAttention(nn.Module):
|
113 |
+
def __init__(self, config, is_cross_attention):
|
114 |
+
super().__init__()
|
115 |
+
self.config = config
|
116 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
|
117 |
+
config, "embedding_size"
|
118 |
+
):
|
119 |
+
raise ValueError(
|
120 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
121 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
122 |
+
)
|
123 |
+
|
124 |
+
self.num_attention_heads = config.num_attention_heads
|
125 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
126 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
127 |
+
|
128 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
129 |
+
if is_cross_attention:
|
130 |
+
self.key = nn.Linear(config.encoder_width, self.all_head_size)
|
131 |
+
self.value = nn.Linear(config.encoder_width, self.all_head_size)
|
132 |
+
else:
|
133 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
134 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
135 |
+
|
136 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
137 |
+
self.position_embedding_type = getattr(
|
138 |
+
config, "position_embedding_type", "absolute"
|
139 |
+
)
|
140 |
+
if (
|
141 |
+
self.position_embedding_type == "relative_key"
|
142 |
+
or self.position_embedding_type == "relative_key_query"
|
143 |
+
):
|
144 |
+
self.max_position_embeddings = config.max_position_embeddings
|
145 |
+
self.distance_embedding = nn.Embedding(
|
146 |
+
2 * config.max_position_embeddings - 1, self.attention_head_size
|
147 |
+
)
|
148 |
+
self.save_attention = False
|
149 |
+
|
150 |
+
def save_attn_gradients(self, attn_gradients):
|
151 |
+
self.attn_gradients = attn_gradients
|
152 |
+
|
153 |
+
def get_attn_gradients(self):
|
154 |
+
return self.attn_gradients
|
155 |
+
|
156 |
+
def save_attention_map(self, attention_map):
|
157 |
+
self.attention_map = attention_map
|
158 |
+
|
159 |
+
def get_attention_map(self):
|
160 |
+
return self.attention_map
|
161 |
+
|
162 |
+
def transpose_for_scores(self, x):
|
163 |
+
new_x_shape = x.size()[:-1] + (
|
164 |
+
self.num_attention_heads,
|
165 |
+
self.attention_head_size,
|
166 |
+
)
|
167 |
+
x = x.view(*new_x_shape)
|
168 |
+
return x.permute(0, 2, 1, 3)
|
169 |
+
|
170 |
+
def forward(
|
171 |
+
self,
|
172 |
+
hidden_states,
|
173 |
+
attention_mask=None,
|
174 |
+
head_mask=None,
|
175 |
+
encoder_hidden_states=None,
|
176 |
+
encoder_attention_mask=None,
|
177 |
+
past_key_value=None,
|
178 |
+
output_attentions=False,
|
179 |
+
):
|
180 |
+
|
181 |
+
# If this is instantiated as a cross-attention module, the keys
|
182 |
+
# and values come from an encoder; the attention mask needs to be
|
183 |
+
# such that the encoder's padding tokens are not attended to.
|
184 |
+
is_cross_attention = encoder_hidden_states is not None
|
185 |
+
|
186 |
+
if is_cross_attention:
|
187 |
+
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
188 |
+
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
189 |
+
attention_mask = encoder_attention_mask
|
190 |
+
elif past_key_value is not None:
|
191 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
192 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
193 |
+
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
194 |
+
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
195 |
+
else:
|
196 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
197 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
198 |
+
|
199 |
+
mixed_query_layer = self.query(hidden_states)
|
200 |
+
|
201 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
202 |
+
|
203 |
+
past_key_value = (key_layer, value_layer)
|
204 |
+
|
205 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
206 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
207 |
+
|
208 |
+
if (
|
209 |
+
self.position_embedding_type == "relative_key"
|
210 |
+
or self.position_embedding_type == "relative_key_query"
|
211 |
+
):
|
212 |
+
seq_length = hidden_states.size()[1]
|
213 |
+
position_ids_l = torch.arange(
|
214 |
+
seq_length, dtype=torch.long, device=hidden_states.device
|
215 |
+
).view(-1, 1)
|
216 |
+
position_ids_r = torch.arange(
|
217 |
+
seq_length, dtype=torch.long, device=hidden_states.device
|
218 |
+
).view(1, -1)
|
219 |
+
distance = position_ids_l - position_ids_r
|
220 |
+
positional_embedding = self.distance_embedding(
|
221 |
+
distance + self.max_position_embeddings - 1
|
222 |
+
)
|
223 |
+
positional_embedding = positional_embedding.to(
|
224 |
+
dtype=query_layer.dtype
|
225 |
+
) # fp16 compatibility
|
226 |
+
|
227 |
+
if self.position_embedding_type == "relative_key":
|
228 |
+
relative_position_scores = torch.einsum(
|
229 |
+
"bhld,lrd->bhlr", query_layer, positional_embedding
|
230 |
+
)
|
231 |
+
attention_scores = attention_scores + relative_position_scores
|
232 |
+
elif self.position_embedding_type == "relative_key_query":
|
233 |
+
relative_position_scores_query = torch.einsum(
|
234 |
+
"bhld,lrd->bhlr", query_layer, positional_embedding
|
235 |
+
)
|
236 |
+
relative_position_scores_key = torch.einsum(
|
237 |
+
"bhrd,lrd->bhlr", key_layer, positional_embedding
|
238 |
+
)
|
239 |
+
attention_scores = (
|
240 |
+
attention_scores
|
241 |
+
+ relative_position_scores_query
|
242 |
+
+ relative_position_scores_key
|
243 |
+
)
|
244 |
+
|
245 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
246 |
+
if attention_mask is not None:
|
247 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
248 |
+
attention_scores = attention_scores + attention_mask
|
249 |
+
|
250 |
+
# Normalize the attention scores to probabilities.
|
251 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
252 |
+
|
253 |
+
if is_cross_attention and self.save_attention:
|
254 |
+
self.save_attention_map(attention_probs)
|
255 |
+
attention_probs.register_hook(self.save_attn_gradients)
|
256 |
+
|
257 |
+
# This is actually dropping out entire tokens to attend to, which might
|
258 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
259 |
+
attention_probs_dropped = self.dropout(attention_probs)
|
260 |
+
|
261 |
+
# Mask heads if we want to
|
262 |
+
if head_mask is not None:
|
263 |
+
attention_probs_dropped = attention_probs_dropped * head_mask
|
264 |
+
|
265 |
+
context_layer = torch.matmul(attention_probs_dropped, value_layer)
|
266 |
+
|
267 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
268 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
269 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
270 |
+
|
271 |
+
outputs = (
|
272 |
+
(context_layer, attention_probs) if output_attentions else (context_layer,)
|
273 |
+
)
|
274 |
+
|
275 |
+
outputs = outputs + (past_key_value,)
|
276 |
+
return outputs
|
277 |
+
|
278 |
+
|
279 |
+
class DropPath(nn.Module):
|
280 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
281 |
+
"""
|
282 |
+
def __init__(self, drop_prob=None):
|
283 |
+
super(DropPath, self).__init__()
|
284 |
+
self.drop_prob = drop_prob
|
285 |
+
|
286 |
+
def forward(self, x):
|
287 |
+
return drop_path(x, self.drop_prob, self.training)
|
288 |
+
|
289 |
+
def extra_repr(self) -> str:
|
290 |
+
return 'p={}'.format(self.drop_prob)
|
291 |
+
|
292 |
+
|
293 |
+
class BertSelfOutput(nn.Module):
|
294 |
+
def __init__(self, config, drop_path=0.):
|
295 |
+
super().__init__()
|
296 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
297 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
298 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
299 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
300 |
+
|
301 |
+
def forward(self, hidden_states, input_tensor):
|
302 |
+
hidden_states = self.dense(hidden_states)
|
303 |
+
hidden_states = self.dropout(hidden_states)
|
304 |
+
hidden_states = self.drop_path(hidden_states)
|
305 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
306 |
+
return hidden_states
|
307 |
+
|
308 |
+
|
309 |
+
class BertAttention(nn.Module):
|
310 |
+
def __init__(self, config, is_cross_attention=False, drop_path=0.,):
|
311 |
+
super().__init__()
|
312 |
+
self.self = BertSelfAttention(config, is_cross_attention)
|
313 |
+
self.output = BertSelfOutput(config, drop_path=drop_path)
|
314 |
+
self.pruned_heads = set()
|
315 |
+
|
316 |
+
def prune_heads(self, heads):
|
317 |
+
if len(heads) == 0:
|
318 |
+
return
|
319 |
+
heads, index = find_pruneable_heads_and_indices(
|
320 |
+
heads,
|
321 |
+
self.self.num_attention_heads,
|
322 |
+
self.self.attention_head_size,
|
323 |
+
self.pruned_heads,
|
324 |
+
)
|
325 |
+
|
326 |
+
# Prune linear layers
|
327 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
328 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
329 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
330 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
331 |
+
|
332 |
+
# Update hyper params and store pruned heads
|
333 |
+
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
334 |
+
self.self.all_head_size = (
|
335 |
+
self.self.attention_head_size * self.self.num_attention_heads
|
336 |
+
)
|
337 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
338 |
+
|
339 |
+
def forward(
|
340 |
+
self,
|
341 |
+
hidden_states,
|
342 |
+
attention_mask=None,
|
343 |
+
head_mask=None,
|
344 |
+
encoder_hidden_states=None,
|
345 |
+
encoder_attention_mask=None,
|
346 |
+
past_key_value=None,
|
347 |
+
output_attentions=False,
|
348 |
+
):
|
349 |
+
self_outputs = self.self(
|
350 |
+
hidden_states,
|
351 |
+
attention_mask,
|
352 |
+
head_mask,
|
353 |
+
encoder_hidden_states,
|
354 |
+
encoder_attention_mask,
|
355 |
+
past_key_value,
|
356 |
+
output_attentions,
|
357 |
+
)
|
358 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
359 |
+
|
360 |
+
outputs = (attention_output,) + self_outputs[
|
361 |
+
1:
|
362 |
+
] # add attentions if we output them
|
363 |
+
return outputs
|
364 |
+
|
365 |
+
|
366 |
+
class BertIntermediate(nn.Module):
|
367 |
+
def __init__(self, config):
|
368 |
+
super().__init__()
|
369 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
370 |
+
if isinstance(config.hidden_act, str):
|
371 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
372 |
+
else:
|
373 |
+
self.intermediate_act_fn = config.hidden_act
|
374 |
+
|
375 |
+
def forward(self, hidden_states):
|
376 |
+
hidden_states = self.dense(hidden_states)
|
377 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
378 |
+
return hidden_states
|
379 |
+
|
380 |
+
|
381 |
+
class BertOutput(nn.Module):
|
382 |
+
def __init__(self, config, drop_path=0.):
|
383 |
+
super().__init__()
|
384 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
385 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
386 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
387 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
388 |
+
|
389 |
+
def forward(self, hidden_states, input_tensor):
|
390 |
+
hidden_states = self.dense(hidden_states)
|
391 |
+
hidden_states = self.dropout(hidden_states)
|
392 |
+
hidden_states = self.drop_path(hidden_states)
|
393 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
394 |
+
return hidden_states
|
395 |
+
|
396 |
+
|
397 |
+
class BertLayer(nn.Module):
|
398 |
+
def __init__(self, config, layer_num):
|
399 |
+
super().__init__()
|
400 |
+
self.config = config
|
401 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
402 |
+
self.seq_len_dim = 1
|
403 |
+
drop_path = config.drop_path_list[layer_num]
|
404 |
+
self.attention = BertAttention(config, drop_path=drop_path)
|
405 |
+
self.layer_num = layer_num
|
406 |
+
if (
|
407 |
+
self.config.add_cross_attention
|
408 |
+
and layer_num % self.config.cross_attention_freq == 0
|
409 |
+
):
|
410 |
+
self.crossattention = BertAttention(
|
411 |
+
config, is_cross_attention=self.config.add_cross_attention,
|
412 |
+
drop_path=drop_path
|
413 |
+
)
|
414 |
+
self.has_cross_attention = True
|
415 |
+
else:
|
416 |
+
self.has_cross_attention = False
|
417 |
+
self.intermediate = BertIntermediate(config)
|
418 |
+
self.output = BertOutput(config, drop_path=drop_path)
|
419 |
+
|
420 |
+
self.intermediate_query = BertIntermediate(config)
|
421 |
+
self.output_query = BertOutput(config, drop_path=drop_path)
|
422 |
+
|
423 |
+
def forward(
|
424 |
+
self,
|
425 |
+
hidden_states,
|
426 |
+
attention_mask=None,
|
427 |
+
head_mask=None,
|
428 |
+
encoder_hidden_states=None,
|
429 |
+
encoder_attention_mask=None,
|
430 |
+
past_key_value=None,
|
431 |
+
output_attentions=False,
|
432 |
+
query_length=0,
|
433 |
+
):
|
434 |
+
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
435 |
+
self_attn_past_key_value = (
|
436 |
+
past_key_value[:2] if past_key_value is not None else None
|
437 |
+
)
|
438 |
+
self_attention_outputs = self.attention(
|
439 |
+
hidden_states,
|
440 |
+
attention_mask,
|
441 |
+
head_mask,
|
442 |
+
output_attentions=output_attentions,
|
443 |
+
past_key_value=self_attn_past_key_value,
|
444 |
+
)
|
445 |
+
attention_output = self_attention_outputs[0]
|
446 |
+
outputs = self_attention_outputs[1:-1]
|
447 |
+
|
448 |
+
present_key_value = self_attention_outputs[-1]
|
449 |
+
|
450 |
+
if query_length > 0:
|
451 |
+
query_attention_output = attention_output[:, :query_length, :]
|
452 |
+
|
453 |
+
if self.has_cross_attention:
|
454 |
+
assert (
|
455 |
+
encoder_hidden_states is not None
|
456 |
+
), "encoder_hidden_states must be given for cross-attention layers"
|
457 |
+
cross_attention_outputs = self.crossattention(
|
458 |
+
query_attention_output,
|
459 |
+
attention_mask,
|
460 |
+
head_mask,
|
461 |
+
encoder_hidden_states,
|
462 |
+
encoder_attention_mask,
|
463 |
+
output_attentions=output_attentions,
|
464 |
+
)
|
465 |
+
query_attention_output = cross_attention_outputs[0]
|
466 |
+
outputs = (
|
467 |
+
outputs + cross_attention_outputs[1:-1]
|
468 |
+
) # add cross attentions if we output attention weights
|
469 |
+
|
470 |
+
layer_output = apply_chunking_to_forward(
|
471 |
+
self.feed_forward_chunk_query,
|
472 |
+
self.chunk_size_feed_forward,
|
473 |
+
self.seq_len_dim,
|
474 |
+
query_attention_output,
|
475 |
+
)
|
476 |
+
if attention_output.shape[1] > query_length:
|
477 |
+
layer_output_text = apply_chunking_to_forward(
|
478 |
+
self.feed_forward_chunk,
|
479 |
+
self.chunk_size_feed_forward,
|
480 |
+
self.seq_len_dim,
|
481 |
+
attention_output[:, query_length:, :],
|
482 |
+
)
|
483 |
+
layer_output = torch.cat([layer_output, layer_output_text], dim=1)
|
484 |
+
else:
|
485 |
+
layer_output = apply_chunking_to_forward(
|
486 |
+
self.feed_forward_chunk,
|
487 |
+
self.chunk_size_feed_forward,
|
488 |
+
self.seq_len_dim,
|
489 |
+
attention_output,
|
490 |
+
)
|
491 |
+
outputs = (layer_output,) + outputs
|
492 |
+
|
493 |
+
outputs = outputs + (present_key_value,)
|
494 |
+
|
495 |
+
return outputs
|
496 |
+
|
497 |
+
def feed_forward_chunk(self, attention_output):
|
498 |
+
intermediate_output = self.intermediate(attention_output)
|
499 |
+
layer_output = self.output(intermediate_output, attention_output)
|
500 |
+
return layer_output
|
501 |
+
|
502 |
+
def feed_forward_chunk_query(self, attention_output):
|
503 |
+
intermediate_output = self.intermediate_query(attention_output)
|
504 |
+
layer_output = self.output_query(intermediate_output, attention_output)
|
505 |
+
return layer_output
|
506 |
+
|
507 |
+
|
508 |
+
class BertEncoder(nn.Module):
|
509 |
+
def __init__(self, config):
|
510 |
+
super().__init__()
|
511 |
+
self.config = config
|
512 |
+
self.layer = nn.ModuleList(
|
513 |
+
[BertLayer(config, i) for i in range(config.num_hidden_layers)]
|
514 |
+
)
|
515 |
+
|
516 |
+
def forward(
|
517 |
+
self,
|
518 |
+
hidden_states,
|
519 |
+
attention_mask=None,
|
520 |
+
head_mask=None,
|
521 |
+
encoder_hidden_states=None,
|
522 |
+
encoder_attention_mask=None,
|
523 |
+
past_key_values=None,
|
524 |
+
use_cache=None,
|
525 |
+
output_attentions=False,
|
526 |
+
output_hidden_states=False,
|
527 |
+
return_dict=True,
|
528 |
+
query_length=0,
|
529 |
+
):
|
530 |
+
all_hidden_states = () if output_hidden_states else None
|
531 |
+
all_self_attentions = () if output_attentions else None
|
532 |
+
all_cross_attentions = (
|
533 |
+
() if output_attentions and self.config.add_cross_attention else None
|
534 |
+
)
|
535 |
+
|
536 |
+
next_decoder_cache = () if use_cache else None
|
537 |
+
|
538 |
+
for i in range(self.config.num_hidden_layers):
|
539 |
+
layer_module = self.layer[i]
|
540 |
+
if output_hidden_states:
|
541 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
542 |
+
|
543 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
544 |
+
past_key_value = past_key_values[i] if past_key_values is not None else None
|
545 |
+
|
546 |
+
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
547 |
+
|
548 |
+
if use_cache:
|
549 |
+
logger.warn(
|
550 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
551 |
+
)
|
552 |
+
use_cache = False
|
553 |
+
|
554 |
+
def create_custom_forward(module):
|
555 |
+
def custom_forward(*inputs):
|
556 |
+
return module(
|
557 |
+
*inputs, past_key_value, output_attentions, query_length
|
558 |
+
)
|
559 |
+
|
560 |
+
return custom_forward
|
561 |
+
|
562 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
563 |
+
create_custom_forward(layer_module),
|
564 |
+
hidden_states,
|
565 |
+
attention_mask,
|
566 |
+
layer_head_mask,
|
567 |
+
encoder_hidden_states,
|
568 |
+
encoder_attention_mask,
|
569 |
+
)
|
570 |
+
else:
|
571 |
+
layer_outputs = layer_module(
|
572 |
+
hidden_states,
|
573 |
+
attention_mask,
|
574 |
+
layer_head_mask,
|
575 |
+
encoder_hidden_states,
|
576 |
+
encoder_attention_mask,
|
577 |
+
past_key_value,
|
578 |
+
output_attentions,
|
579 |
+
query_length,
|
580 |
+
)
|
581 |
+
|
582 |
+
hidden_states = layer_outputs[0]
|
583 |
+
if use_cache:
|
584 |
+
next_decoder_cache += (layer_outputs[-1],)
|
585 |
+
if output_attentions:
|
586 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
587 |
+
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
588 |
+
|
589 |
+
if output_hidden_states:
|
590 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
591 |
+
|
592 |
+
if not return_dict:
|
593 |
+
return tuple(
|
594 |
+
v
|
595 |
+
for v in [
|
596 |
+
hidden_states,
|
597 |
+
next_decoder_cache,
|
598 |
+
all_hidden_states,
|
599 |
+
all_self_attentions,
|
600 |
+
all_cross_attentions,
|
601 |
+
]
|
602 |
+
if v is not None
|
603 |
+
)
|
604 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
605 |
+
last_hidden_state=hidden_states,
|
606 |
+
past_key_values=next_decoder_cache,
|
607 |
+
hidden_states=all_hidden_states,
|
608 |
+
attentions=all_self_attentions,
|
609 |
+
cross_attentions=all_cross_attentions,
|
610 |
+
)
|
611 |
+
|
612 |
+
|
613 |
+
class BertPooler(nn.Module):
|
614 |
+
def __init__(self, config):
|
615 |
+
super().__init__()
|
616 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
617 |
+
self.activation = nn.Tanh()
|
618 |
+
|
619 |
+
def forward(self, hidden_states):
|
620 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
621 |
+
# to the first token.
|
622 |
+
first_token_tensor = hidden_states[:, 0]
|
623 |
+
pooled_output = self.dense(first_token_tensor)
|
624 |
+
pooled_output = self.activation(pooled_output)
|
625 |
+
return pooled_output
|
626 |
+
|
627 |
+
|
628 |
+
class BertPredictionHeadTransform(nn.Module):
|
629 |
+
def __init__(self, config):
|
630 |
+
super().__init__()
|
631 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
632 |
+
if isinstance(config.hidden_act, str):
|
633 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
634 |
+
else:
|
635 |
+
self.transform_act_fn = config.hidden_act
|
636 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
637 |
+
|
638 |
+
def forward(self, hidden_states):
|
639 |
+
hidden_states = self.dense(hidden_states)
|
640 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
641 |
+
hidden_states = self.LayerNorm(hidden_states)
|
642 |
+
return hidden_states
|
643 |
+
|
644 |
+
|
645 |
+
class BertLMPredictionHead(nn.Module):
|
646 |
+
def __init__(self, config):
|
647 |
+
super().__init__()
|
648 |
+
self.transform = BertPredictionHeadTransform(config)
|
649 |
+
|
650 |
+
# The output weights are the same as the input embeddings, but there is
|
651 |
+
# an output-only bias for each token.
|
652 |
+
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
653 |
+
|
654 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
655 |
+
|
656 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
657 |
+
self.decoder.bias = self.bias
|
658 |
+
|
659 |
+
def forward(self, hidden_states):
|
660 |
+
hidden_states = self.transform(hidden_states)
|
661 |
+
hidden_states = self.decoder(hidden_states)
|
662 |
+
return hidden_states
|
663 |
+
|
664 |
+
|
665 |
+
class BertOnlyMLMHead(nn.Module):
|
666 |
+
def __init__(self, config):
|
667 |
+
super().__init__()
|
668 |
+
self.predictions = BertLMPredictionHead(config)
|
669 |
+
|
670 |
+
def forward(self, sequence_output):
|
671 |
+
prediction_scores = self.predictions(sequence_output)
|
672 |
+
return prediction_scores
|
673 |
+
|
674 |
+
|
675 |
+
class BertPreTrainedModel(PreTrainedModel):
|
676 |
+
"""
|
677 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
678 |
+
models.
|
679 |
+
"""
|
680 |
+
|
681 |
+
config_class = BertConfig
|
682 |
+
base_model_prefix = "bert"
|
683 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
684 |
+
|
685 |
+
def _init_weights(self, module):
|
686 |
+
"""Initialize the weights"""
|
687 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
688 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
689 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
690 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
691 |
+
elif isinstance(module, nn.LayerNorm):
|
692 |
+
module.bias.data.zero_()
|
693 |
+
module.weight.data.fill_(1.0)
|
694 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
695 |
+
module.bias.data.zero_()
|
696 |
+
|
697 |
+
|
698 |
+
class BertModel(BertPreTrainedModel):
|
699 |
+
"""
|
700 |
+
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
701 |
+
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
702 |
+
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
703 |
+
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
704 |
+
argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
|
705 |
+
input to the forward pass.
|
706 |
+
"""
|
707 |
+
|
708 |
+
def __init__(self, config, add_pooling_layer=False):
|
709 |
+
super().__init__(config)
|
710 |
+
self.config = config
|
711 |
+
|
712 |
+
self.embeddings = BertEmbeddings(config)
|
713 |
+
|
714 |
+
self.encoder = BertEncoder(config)
|
715 |
+
|
716 |
+
self.pooler = BertPooler(config) if add_pooling_layer else None
|
717 |
+
|
718 |
+
self.init_weights()
|
719 |
+
|
720 |
+
def get_input_embeddings(self):
|
721 |
+
return self.embeddings.word_embeddings
|
722 |
+
|
723 |
+
def set_input_embeddings(self, value):
|
724 |
+
self.embeddings.word_embeddings = value
|
725 |
+
|
726 |
+
def _prune_heads(self, heads_to_prune):
|
727 |
+
"""
|
728 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
729 |
+
class PreTrainedModel
|
730 |
+
"""
|
731 |
+
for layer, heads in heads_to_prune.items():
|
732 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
733 |
+
|
734 |
+
def get_extended_attention_mask(
|
735 |
+
self,
|
736 |
+
attention_mask: Tensor,
|
737 |
+
input_shape: Tuple[int],
|
738 |
+
device: device,
|
739 |
+
is_decoder: bool,
|
740 |
+
has_query: bool = False,
|
741 |
+
) -> Tensor:
|
742 |
+
"""
|
743 |
+
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
744 |
+
|
745 |
+
Arguments:
|
746 |
+
attention_mask (:obj:`torch.Tensor`):
|
747 |
+
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
748 |
+
input_shape (:obj:`Tuple[int]`):
|
749 |
+
The shape of the input to the model.
|
750 |
+
device: (:obj:`torch.device`):
|
751 |
+
The device of the input to the model.
|
752 |
+
|
753 |
+
Returns:
|
754 |
+
:obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
|
755 |
+
"""
|
756 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
757 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
758 |
+
if attention_mask.dim() == 3:
|
759 |
+
extended_attention_mask = attention_mask[:, None, :, :]
|
760 |
+
elif attention_mask.dim() == 2:
|
761 |
+
# Provided a padding mask of dimensions [batch_size, seq_length]
|
762 |
+
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
763 |
+
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
764 |
+
if is_decoder:
|
765 |
+
batch_size, seq_length = input_shape
|
766 |
+
|
767 |
+
seq_ids = torch.arange(seq_length, device=device)
|
768 |
+
causal_mask = (
|
769 |
+
seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
|
770 |
+
<= seq_ids[None, :, None]
|
771 |
+
)
|
772 |
+
|
773 |
+
# add a prefix ones mask to the causal mask
|
774 |
+
# causal and attention masks must have same type with pytorch version < 1.3
|
775 |
+
causal_mask = causal_mask.to(attention_mask.dtype)
|
776 |
+
|
777 |
+
if causal_mask.shape[1] < attention_mask.shape[1]:
|
778 |
+
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
|
779 |
+
if has_query: # UniLM style attention mask
|
780 |
+
causal_mask = torch.cat(
|
781 |
+
[
|
782 |
+
torch.zeros(
|
783 |
+
(batch_size, prefix_seq_len, seq_length),
|
784 |
+
device=device,
|
785 |
+
dtype=causal_mask.dtype,
|
786 |
+
),
|
787 |
+
causal_mask,
|
788 |
+
],
|
789 |
+
axis=1,
|
790 |
+
)
|
791 |
+
causal_mask = torch.cat(
|
792 |
+
[
|
793 |
+
torch.ones(
|
794 |
+
(batch_size, causal_mask.shape[1], prefix_seq_len),
|
795 |
+
device=device,
|
796 |
+
dtype=causal_mask.dtype,
|
797 |
+
),
|
798 |
+
causal_mask,
|
799 |
+
],
|
800 |
+
axis=-1,
|
801 |
+
)
|
802 |
+
extended_attention_mask = (
|
803 |
+
causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
804 |
+
)
|
805 |
+
else:
|
806 |
+
extended_attention_mask = attention_mask[:, None, None, :]
|
807 |
+
else:
|
808 |
+
raise ValueError(
|
809 |
+
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
810 |
+
input_shape, attention_mask.shape
|
811 |
+
)
|
812 |
+
)
|
813 |
+
|
814 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
815 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
816 |
+
# positions we want to attend and -10000.0 for masked positions.
|
817 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
818 |
+
# effectively the same as removing these entirely.
|
819 |
+
extended_attention_mask = extended_attention_mask.to(
|
820 |
+
dtype=self.dtype
|
821 |
+
) # fp16 compatibility
|
822 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
823 |
+
return extended_attention_mask
|
824 |
+
|
825 |
+
def forward(
|
826 |
+
self,
|
827 |
+
input_ids=None,
|
828 |
+
attention_mask=None,
|
829 |
+
position_ids=None,
|
830 |
+
head_mask=None,
|
831 |
+
query_embeds=None,
|
832 |
+
encoder_hidden_states=None,
|
833 |
+
encoder_attention_mask=None,
|
834 |
+
past_key_values=None,
|
835 |
+
use_cache=None,
|
836 |
+
output_attentions=None,
|
837 |
+
output_hidden_states=None,
|
838 |
+
return_dict=None,
|
839 |
+
is_decoder=False,
|
840 |
+
):
|
841 |
+
r"""
|
842 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
843 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
844 |
+
the model is configured as a decoder.
|
845 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
846 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
847 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
848 |
+
- 1 for tokens that are **not masked**,
|
849 |
+
- 0 for tokens that are **masked**.
|
850 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
851 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
852 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
853 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
854 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
855 |
+
use_cache (:obj:`bool`, `optional`):
|
856 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
857 |
+
decoding (see :obj:`past_key_values`).
|
858 |
+
"""
|
859 |
+
output_attentions = (
|
860 |
+
output_attentions
|
861 |
+
if output_attentions is not None
|
862 |
+
else self.config.output_attentions
|
863 |
+
)
|
864 |
+
output_hidden_states = (
|
865 |
+
output_hidden_states
|
866 |
+
if output_hidden_states is not None
|
867 |
+
else self.config.output_hidden_states
|
868 |
+
)
|
869 |
+
return_dict = (
|
870 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
871 |
+
)
|
872 |
+
|
873 |
+
# use_cache = use_cache if use_cache is not None else self.config.use_cache
|
874 |
+
|
875 |
+
if input_ids is None:
|
876 |
+
assert (
|
877 |
+
query_embeds is not None
|
878 |
+
), "You have to specify query_embeds when input_ids is None"
|
879 |
+
|
880 |
+
# past_key_values_length
|
881 |
+
past_key_values_length = (
|
882 |
+
past_key_values[0][0].shape[2] - self.config.query_length
|
883 |
+
if past_key_values is not None
|
884 |
+
else 0
|
885 |
+
)
|
886 |
+
|
887 |
+
query_length = query_embeds.shape[1] if query_embeds is not None else 0
|
888 |
+
|
889 |
+
embedding_output = self.embeddings(
|
890 |
+
input_ids=input_ids,
|
891 |
+
position_ids=position_ids,
|
892 |
+
query_embeds=query_embeds,
|
893 |
+
past_key_values_length=past_key_values_length,
|
894 |
+
)
|
895 |
+
|
896 |
+
input_shape = embedding_output.size()[:-1]
|
897 |
+
batch_size, seq_length = input_shape
|
898 |
+
device = embedding_output.device
|
899 |
+
|
900 |
+
if attention_mask is None:
|
901 |
+
attention_mask = torch.ones(
|
902 |
+
((batch_size, seq_length + past_key_values_length)), device=device
|
903 |
+
)
|
904 |
+
|
905 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
906 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
907 |
+
if is_decoder:
|
908 |
+
extended_attention_mask = self.get_extended_attention_mask(
|
909 |
+
attention_mask,
|
910 |
+
input_ids.shape,
|
911 |
+
device,
|
912 |
+
is_decoder,
|
913 |
+
has_query=(query_embeds is not None),
|
914 |
+
)
|
915 |
+
else:
|
916 |
+
extended_attention_mask = self.get_extended_attention_mask(
|
917 |
+
attention_mask, input_shape, device, is_decoder
|
918 |
+
)
|
919 |
+
|
920 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
921 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
922 |
+
if encoder_hidden_states is not None:
|
923 |
+
if type(encoder_hidden_states) == list:
|
924 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
|
925 |
+
0
|
926 |
+
].size()
|
927 |
+
else:
|
928 |
+
(
|
929 |
+
encoder_batch_size,
|
930 |
+
encoder_sequence_length,
|
931 |
+
_,
|
932 |
+
) = encoder_hidden_states.size()
|
933 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
934 |
+
|
935 |
+
if type(encoder_attention_mask) == list:
|
936 |
+
encoder_extended_attention_mask = [
|
937 |
+
self.invert_attention_mask(mask) for mask in encoder_attention_mask
|
938 |
+
]
|
939 |
+
elif encoder_attention_mask is None:
|
940 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
941 |
+
encoder_extended_attention_mask = self.invert_attention_mask(
|
942 |
+
encoder_attention_mask
|
943 |
+
)
|
944 |
+
else:
|
945 |
+
encoder_extended_attention_mask = self.invert_attention_mask(
|
946 |
+
encoder_attention_mask
|
947 |
+
)
|
948 |
+
else:
|
949 |
+
encoder_extended_attention_mask = None
|
950 |
+
|
951 |
+
# Prepare head mask if needed
|
952 |
+
# 1.0 in head_mask indicate we keep the head
|
953 |
+
# attention_probs has shape bsz x n_heads x N x N
|
954 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
955 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
956 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
957 |
+
|
958 |
+
encoder_outputs = self.encoder(
|
959 |
+
embedding_output,
|
960 |
+
attention_mask=extended_attention_mask,
|
961 |
+
head_mask=head_mask,
|
962 |
+
encoder_hidden_states=encoder_hidden_states,
|
963 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
964 |
+
past_key_values=past_key_values,
|
965 |
+
use_cache=use_cache,
|
966 |
+
output_attentions=output_attentions,
|
967 |
+
output_hidden_states=output_hidden_states,
|
968 |
+
return_dict=return_dict,
|
969 |
+
query_length=query_length,
|
970 |
+
)
|
971 |
+
sequence_output = encoder_outputs[0]
|
972 |
+
pooled_output = (
|
973 |
+
self.pooler(sequence_output) if self.pooler is not None else None
|
974 |
+
)
|
975 |
+
|
976 |
+
if not return_dict:
|
977 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
978 |
+
|
979 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
980 |
+
last_hidden_state=sequence_output,
|
981 |
+
pooler_output=pooled_output,
|
982 |
+
past_key_values=encoder_outputs.past_key_values,
|
983 |
+
hidden_states=encoder_outputs.hidden_states,
|
984 |
+
attentions=encoder_outputs.attentions,
|
985 |
+
cross_attentions=encoder_outputs.cross_attentions,
|
986 |
+
)
|
987 |
+
|
988 |
+
|
989 |
+
class BertLMHeadModel(BertPreTrainedModel):
|
990 |
+
|
991 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
992 |
+
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
993 |
+
|
994 |
+
def __init__(self, config):
|
995 |
+
super().__init__(config)
|
996 |
+
|
997 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
998 |
+
self.cls = BertOnlyMLMHead(config)
|
999 |
+
|
1000 |
+
self.init_weights()
|
1001 |
+
|
1002 |
+
def get_output_embeddings(self):
|
1003 |
+
return self.cls.predictions.decoder
|
1004 |
+
|
1005 |
+
def set_output_embeddings(self, new_embeddings):
|
1006 |
+
self.cls.predictions.decoder = new_embeddings
|
1007 |
+
|
1008 |
+
def forward(
|
1009 |
+
self,
|
1010 |
+
input_ids=None,
|
1011 |
+
attention_mask=None,
|
1012 |
+
position_ids=None,
|
1013 |
+
head_mask=None,
|
1014 |
+
query_embeds=None,
|
1015 |
+
encoder_hidden_states=None,
|
1016 |
+
encoder_attention_mask=None,
|
1017 |
+
labels=None,
|
1018 |
+
past_key_values=None,
|
1019 |
+
use_cache=True,
|
1020 |
+
output_attentions=None,
|
1021 |
+
output_hidden_states=None,
|
1022 |
+
return_dict=None,
|
1023 |
+
return_logits=False,
|
1024 |
+
is_decoder=True,
|
1025 |
+
reduction="mean",
|
1026 |
+
):
|
1027 |
+
r"""
|
1028 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
1029 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
1030 |
+
the model is configured as a decoder.
|
1031 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
1032 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
1033 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
1034 |
+
- 1 for tokens that are **not masked**,
|
1035 |
+
- 0 for tokens that are **masked**.
|
1036 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
1037 |
+
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
1038 |
+
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
|
1039 |
+
ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
|
1040 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
1041 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
1042 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
1043 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
1044 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
1045 |
+
use_cache (:obj:`bool`, `optional`):
|
1046 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
1047 |
+
decoding (see :obj:`past_key_values`).
|
1048 |
+
Returns:
|
1049 |
+
Example::
|
1050 |
+
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
|
1051 |
+
>>> import torch
|
1052 |
+
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
1053 |
+
>>> config = BertConfig.from_pretrained("bert-base-cased")
|
1054 |
+
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
|
1055 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
1056 |
+
>>> outputs = model(**inputs)
|
1057 |
+
>>> prediction_logits = outputs.logits
|
1058 |
+
"""
|
1059 |
+
return_dict = (
|
1060 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1061 |
+
)
|
1062 |
+
if labels is not None:
|
1063 |
+
use_cache = False
|
1064 |
+
if past_key_values is not None:
|
1065 |
+
query_embeds = None
|
1066 |
+
|
1067 |
+
outputs = self.bert(
|
1068 |
+
input_ids,
|
1069 |
+
attention_mask=attention_mask,
|
1070 |
+
position_ids=position_ids,
|
1071 |
+
head_mask=head_mask,
|
1072 |
+
query_embeds=query_embeds,
|
1073 |
+
encoder_hidden_states=encoder_hidden_states,
|
1074 |
+
encoder_attention_mask=encoder_attention_mask,
|
1075 |
+
past_key_values=past_key_values,
|
1076 |
+
use_cache=use_cache,
|
1077 |
+
output_attentions=output_attentions,
|
1078 |
+
output_hidden_states=output_hidden_states,
|
1079 |
+
return_dict=return_dict,
|
1080 |
+
is_decoder=is_decoder,
|
1081 |
+
)
|
1082 |
+
|
1083 |
+
sequence_output = outputs[0]
|
1084 |
+
if query_embeds is not None:
|
1085 |
+
sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
|
1086 |
+
|
1087 |
+
prediction_scores = self.cls(sequence_output)
|
1088 |
+
|
1089 |
+
if return_logits:
|
1090 |
+
return prediction_scores[:, :-1, :].contiguous()
|
1091 |
+
|
1092 |
+
lm_loss = None
|
1093 |
+
if labels is not None:
|
1094 |
+
# we are doing next-token prediction; shift prediction scores and input ids by one
|
1095 |
+
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
1096 |
+
labels = labels[:, 1:].contiguous()
|
1097 |
+
loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
|
1098 |
+
lm_loss = loss_fct(
|
1099 |
+
shifted_prediction_scores.view(-1, self.config.vocab_size),
|
1100 |
+
labels.view(-1),
|
1101 |
+
)
|
1102 |
+
if reduction == "none":
|
1103 |
+
lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
|
1104 |
+
|
1105 |
+
if not return_dict:
|
1106 |
+
output = (prediction_scores,) + outputs[2:]
|
1107 |
+
return ((lm_loss,) + output) if lm_loss is not None else output
|
1108 |
+
|
1109 |
+
return CausalLMOutputWithCrossAttentions(
|
1110 |
+
loss=lm_loss,
|
1111 |
+
logits=prediction_scores,
|
1112 |
+
past_key_values=outputs.past_key_values,
|
1113 |
+
hidden_states=outputs.hidden_states,
|
1114 |
+
attentions=outputs.attentions,
|
1115 |
+
cross_attentions=outputs.cross_attentions,
|
1116 |
+
)
|
1117 |
+
|
1118 |
+
def prepare_inputs_for_generation(
|
1119 |
+
self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
|
1120 |
+
):
|
1121 |
+
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
1122 |
+
if attention_mask is None:
|
1123 |
+
attention_mask = input_ids.new_ones(input_ids.shape)
|
1124 |
+
query_mask = input_ids.new_ones(query_embeds.shape[:-1])
|
1125 |
+
attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
|
1126 |
+
|
1127 |
+
# cut decoder_input_ids if past is used
|
1128 |
+
if past is not None:
|
1129 |
+
input_ids = input_ids[:, -1:]
|
1130 |
+
|
1131 |
+
return {
|
1132 |
+
"input_ids": input_ids,
|
1133 |
+
"query_embeds": query_embeds,
|
1134 |
+
"attention_mask": attention_mask,
|
1135 |
+
"past_key_values": past,
|
1136 |
+
"encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
|
1137 |
+
"encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
|
1138 |
+
"is_decoder": True,
|
1139 |
+
}
|
1140 |
+
|
1141 |
+
def _reorder_cache(self, past, beam_idx):
|
1142 |
+
reordered_past = ()
|
1143 |
+
for layer_past in past:
|
1144 |
+
reordered_past += (
|
1145 |
+
tuple(
|
1146 |
+
past_state.index_select(0, beam_idx) for past_state in layer_past
|
1147 |
+
),
|
1148 |
+
)
|
1149 |
+
return reordered_past
|
1150 |
+
|
1151 |
+
|
1152 |
+
class BertForMaskedLM(BertPreTrainedModel):
|
1153 |
+
|
1154 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
1155 |
+
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
1156 |
+
|
1157 |
+
def __init__(self, config):
|
1158 |
+
super().__init__(config)
|
1159 |
+
|
1160 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
1161 |
+
self.cls = BertOnlyMLMHead(config)
|
1162 |
+
|
1163 |
+
self.init_weights()
|
1164 |
+
|
1165 |
+
def get_output_embeddings(self):
|
1166 |
+
return self.cls.predictions.decoder
|
1167 |
+
|
1168 |
+
def set_output_embeddings(self, new_embeddings):
|
1169 |
+
self.cls.predictions.decoder = new_embeddings
|
1170 |
+
|
1171 |
+
def forward(
|
1172 |
+
self,
|
1173 |
+
input_ids=None,
|
1174 |
+
attention_mask=None,
|
1175 |
+
position_ids=None,
|
1176 |
+
head_mask=None,
|
1177 |
+
query_embeds=None,
|
1178 |
+
encoder_hidden_states=None,
|
1179 |
+
encoder_attention_mask=None,
|
1180 |
+
labels=None,
|
1181 |
+
output_attentions=None,
|
1182 |
+
output_hidden_states=None,
|
1183 |
+
return_dict=None,
|
1184 |
+
return_logits=False,
|
1185 |
+
is_decoder=False,
|
1186 |
+
):
|
1187 |
+
r"""
|
1188 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
1189 |
+
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
|
1190 |
+
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
|
1191 |
+
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
|
1192 |
+
"""
|
1193 |
+
|
1194 |
+
return_dict = (
|
1195 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1196 |
+
)
|
1197 |
+
|
1198 |
+
outputs = self.bert(
|
1199 |
+
input_ids,
|
1200 |
+
attention_mask=attention_mask,
|
1201 |
+
position_ids=position_ids,
|
1202 |
+
head_mask=head_mask,
|
1203 |
+
query_embeds=query_embeds,
|
1204 |
+
encoder_hidden_states=encoder_hidden_states,
|
1205 |
+
encoder_attention_mask=encoder_attention_mask,
|
1206 |
+
output_attentions=output_attentions,
|
1207 |
+
output_hidden_states=output_hidden_states,
|
1208 |
+
return_dict=return_dict,
|
1209 |
+
is_decoder=is_decoder,
|
1210 |
+
)
|
1211 |
+
|
1212 |
+
if query_embeds is not None:
|
1213 |
+
sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
|
1214 |
+
prediction_scores = self.cls(sequence_output)
|
1215 |
+
|
1216 |
+
if return_logits:
|
1217 |
+
return prediction_scores
|
1218 |
+
|
1219 |
+
masked_lm_loss = None
|
1220 |
+
if labels is not None:
|
1221 |
+
loss_fct = CrossEntropyLoss() # -100 index = padding token
|
1222 |
+
masked_lm_loss = loss_fct(
|
1223 |
+
prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
|
1224 |
+
)
|
1225 |
+
|
1226 |
+
if not return_dict:
|
1227 |
+
output = (prediction_scores,) + outputs[2:]
|
1228 |
+
return (
|
1229 |
+
((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
1230 |
+
)
|
1231 |
+
|
1232 |
+
return MaskedLMOutput(
|
1233 |
+
loss=masked_lm_loss,
|
1234 |
+
logits=prediction_scores,
|
1235 |
+
hidden_states=outputs.hidden_states,
|
1236 |
+
attentions=outputs.attentions,
|
1237 |
+
)
|
1238 |
+
|
1239 |
+
|
1240 |
+
def build_qformer(num_query_token, vision_width,
|
1241 |
+
qformer_hidden_dropout_prob=0.1,
|
1242 |
+
qformer_attention_probs_dropout_prob=0.1,
|
1243 |
+
qformer_drop_path_rate=0.,
|
1244 |
+
bert_type="bert-base-uncased"
|
1245 |
+
):
|
1246 |
+
|
1247 |
+
encoder_config = BertConfig.from_pretrained(bert_type)
|
1248 |
+
encoder_config.encoder_width = vision_width
|
1249 |
+
# insert cross-attention layer every other block
|
1250 |
+
encoder_config.add_cross_attention = True
|
1251 |
+
encoder_config.cross_attention_freq = 2
|
1252 |
+
encoder_config.query_length = num_query_token
|
1253 |
+
encoder_config.hidden_dropout_prob = qformer_hidden_dropout_prob
|
1254 |
+
encoder_config.attention_probs_dropout_prob = qformer_attention_probs_dropout_prob
|
1255 |
+
encoder_config.drop_path_list = [x.item() for x in torch.linspace(0, qformer_drop_path_rate, encoder_config.num_hidden_layers)]
|
1256 |
+
logger.info(f"Drop_path:{encoder_config.drop_path_list}")
|
1257 |
+
logger.info(encoder_config)
|
1258 |
+
Qformer = BertLMHeadModel(encoder_config)
|
1259 |
+
query_tokens = nn.Parameter(
|
1260 |
+
torch.zeros(1, num_query_token, encoder_config.hidden_size)
|
1261 |
+
)
|
1262 |
+
query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
|
1263 |
+
return Qformer, query_tokens
|
1264 |
+
|
modeling_special_token.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import transformers
|
2 |
+
DEFAULT_IMG_TOKEN = "[IMG]"
|
3 |
+
DEFAULT_IMG_END_TOKEN = "[/IMG]"
|
4 |
+
|
5 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
6 |
+
DEFAULT_VIDEO_TOKEN = "[VIDEO]"
|
7 |
+
|
8 |
+
IMG_TOKEN = "[<IMG_PLH>]"
|
9 |
+
VID_TOKEN = "[<VID_PLH>]"
|
10 |
+
|
11 |
+
BOX_START = '<box_begin>'
|
12 |
+
ATBOXES_PLACEHOLDER = '<box_begin><boxes>'
|
13 |
+
BOXES_PLACEHOLDER = '<boxes>'
|
14 |
+
EXPR_PLACEHOLDER = '<expr>'
|
15 |
+
QUESTION_PLACEHOLDER = '<question>'
|
16 |
+
TIME_START = '<time_begin>'
|
17 |
+
TIME_PLACEHOLDER = '<temp>'
|
18 |
+
ATTEMP_PLACEHOLDER = TIME_START + TIME_PLACEHOLDER
|
19 |
+
TRACK_START='<track_begin>'
|
20 |
+
TRACK_PLACEHOLDER = '<tracking>'
|
21 |
+
TRACK_START_BOX = '<track_box>'
|
22 |
+
ATTRACK_PLACEHOLDER = TRACK_START + TRACK_PLACEHOLDER
|
23 |
+
need_template_list = ['REC', 'flickr', 'tracking', 'tracking2', 'tracking3', 'tracking4']
|
24 |
+
|
25 |
+
load_image_list = ['image', 'REC', 'flickr']
|
26 |
+
load_video_list = ['video', 'TVG', 'tracking', 'tracking2','tracking3', 'tracking4', 'TVG+HL']
|
27 |
+
special_tokens = [BOX_START, TIME_START, TIME_PLACEHOLDER, BOXES_PLACEHOLDER, TRACK_START, TRACK_PLACEHOLDER, TRACK_START_BOX]
|
modeling_videochate.py
ADDED
@@ -0,0 +1,681 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import logging
|
3 |
+
import json
|
4 |
+
import torch
|
5 |
+
import torch.utils.checkpoint
|
6 |
+
from torch import nn
|
7 |
+
from torch.nn import MSELoss
|
8 |
+
from transformers.modeling_outputs import (
|
9 |
+
CausalLMOutputWithPast,
|
10 |
+
)
|
11 |
+
from typing import List, Optional, Tuple, Union
|
12 |
+
from transformers import LlamaForCausalLM
|
13 |
+
from transformers.modeling_outputs import (
|
14 |
+
CausalLMOutputWithPast,
|
15 |
+
)
|
16 |
+
|
17 |
+
from torch.cuda.amp import autocast as autocast
|
18 |
+
import torch.nn.functional as F
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
from .modeling_vit import build_vit, MLP, PostProcess
|
22 |
+
|
23 |
+
from .modeling_qformer import build_qformer
|
24 |
+
from .modeling_base import BaseMLLM
|
25 |
+
|
26 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
27 |
+
logger = logging.getLogger(__name__)
|
28 |
+
|
29 |
+
import pycocotools.mask as mask_util
|
30 |
+
|
31 |
+
from .modeling_base import VID_TOKEN, IMG_TOKEN
|
32 |
+
|
33 |
+
class MultiModalLLM_PT(BaseMLLM):
|
34 |
+
def __init__(
|
35 |
+
self,
|
36 |
+
config,
|
37 |
+
_tokenizer=None
|
38 |
+
):
|
39 |
+
super().__init__(config=config, _tokenizer=_tokenizer)
|
40 |
+
self.use_clip = False
|
41 |
+
self.num_frames = 16
|
42 |
+
self.num_clips = 1
|
43 |
+
self.token_merge_len = 4
|
44 |
+
|
45 |
+
self.per_clip_frames = self.num_frames // self.num_clips
|
46 |
+
|
47 |
+
print(self.config)
|
48 |
+
self.merge_proj = nn.Linear(
|
49 |
+
self.qformer.config.hidden_size*self.token_merge_len, self.config.hidden_size
|
50 |
+
)
|
51 |
+
|
52 |
+
if config.build_decoder:
|
53 |
+
self.track_embed = MLP(self.config.hidden_size, self.config.hidden_size, 3 * 256, 2, dropout=0)
|
54 |
+
self.track_embed_decode2 = MLP(4096, 4096, 4, 2, dropout=0)
|
55 |
+
self.temporal_embed = MLP(self.config.hidden_size, self.config.hidden_size, 2, 2, dropout=0.3)
|
56 |
+
self.action_embed = MLP(self.config.hidden_size, self.config.hidden_size, 1, 2, dropout=0.3)
|
57 |
+
self.postprocess = PostProcess()
|
58 |
+
self.track_token = nn.Parameter(torch.randn((1, 1, 4096)))
|
59 |
+
self.temporal_token = nn.Parameter(torch.randn((1, 1, 4096)))
|
60 |
+
self.box_token = nn.Parameter(torch.randn((1, 1, 4096)))
|
61 |
+
|
62 |
+
|
63 |
+
def forward(
|
64 |
+
self,
|
65 |
+
input_ids: torch.LongTensor = None,
|
66 |
+
attention_mask: Optional[torch.Tensor] = None,
|
67 |
+
labels: Optional[torch.LongTensor] = None,
|
68 |
+
image: Optional[torch.Tensor] = None,
|
69 |
+
video: Optional[torch.Tensor] = None,
|
70 |
+
instruction = None,
|
71 |
+
video_idx = None,
|
72 |
+
image_idx = None,
|
73 |
+
output_boxes = None, # REC
|
74 |
+
input_boxes = None, # tracking inputs
|
75 |
+
text_input = None,
|
76 |
+
video_info = None,
|
77 |
+
temporal_labels = None,
|
78 |
+
gt_masks = None,
|
79 |
+
sam_images = None,
|
80 |
+
size_hw = None,
|
81 |
+
path = None,
|
82 |
+
mask_path = None,
|
83 |
+
tvg_inputs = None,
|
84 |
+
tvg_targets = None,
|
85 |
+
):
|
86 |
+
if text_input is not None:
|
87 |
+
time_instructions = self.get_clip_time_instruct(text_input)
|
88 |
+
else:
|
89 |
+
time_instructions = None
|
90 |
+
text_embeds = self.pad_text_embeds(input_ids=input_ids, image=image, video=video, return_visual=False,
|
91 |
+
video_idx=video_idx, image_idx=image_idx, instruction = instruction,
|
92 |
+
output_boxes = output_boxes, input_boxes=input_boxes, time_instructions = time_instructions)
|
93 |
+
outputs = self.lm(
|
94 |
+
inputs_embeds=text_embeds,
|
95 |
+
attention_mask=attention_mask,
|
96 |
+
labels=labels,
|
97 |
+
output_hidden_states=True,
|
98 |
+
return_dict=True,
|
99 |
+
)
|
100 |
+
loss = outputs.loss
|
101 |
+
logger.info(f'llm loss:{loss}')
|
102 |
+
|
103 |
+
if output_boxes is not None and self.use_bbox_loss:
|
104 |
+
last_hidden_states = outputs.hidden_states[-1]
|
105 |
+
pred_locs = []
|
106 |
+
for idx in range(last_hidden_states.shape[0]):
|
107 |
+
loc_positions = ( (input_ids[idx].flatten() == self.tokenizer.box_token) ).nonzero().flatten()
|
108 |
+
selected_hidden_states = last_hidden_states[idx][loc_positions]
|
109 |
+
pred_locs.append(self.loc_decoder(selected_hidden_states))
|
110 |
+
box_loss = self.box_loss(pred_locs, output_boxes)
|
111 |
+
logger.info(f'box loss:{box_loss}')
|
112 |
+
loss += box_loss
|
113 |
+
|
114 |
+
if (gt_masks is not None or input_boxes is not None) and self.use_mask_loss:
|
115 |
+
last_hidden_states = outputs.hidden_states[-1]
|
116 |
+
pred_masks = []
|
117 |
+
sam_losses = []
|
118 |
+
box_losses = []
|
119 |
+
for idx in range(last_hidden_states.shape[0]):
|
120 |
+
loc_positions = ( (input_ids[idx].flatten() == self.tokenizer.track_token) ).nonzero().flatten()
|
121 |
+
selected_hidden_states = last_hidden_states[idx][loc_positions]
|
122 |
+
embed_sam_boxes = self.track_embed(selected_hidden_states).reshape(1, 3, 256)
|
123 |
+
inference_state = self.sam.init_state_images(sam_images, size_hw[idx][0], size_hw[idx][1])
|
124 |
+
|
125 |
+
if input_boxes is not None:
|
126 |
+
gt_embeds = self.sam.get_prompt_embeding(inference_state, None, None, False, input_boxes[idx], device = text_embeds.device)
|
127 |
+
else:
|
128 |
+
input_boxes = self.find_boundaries_torch(gt_masks.squeeze(0)[:,:,:1].squeeze(2).cpu()).to(text_embeds.device)
|
129 |
+
gt_embeds = self.sam.get_prompt_embeding(inference_state, None, None, False, input_boxes, device = text_embeds.device)
|
130 |
+
pred_locs = [self.track_embed_decode2((selected_hidden_states))[0]]
|
131 |
+
target_boxes = [input_boxes[idx]]
|
132 |
+
|
133 |
+
src_boxes = pred_locs
|
134 |
+
loss_bbox = self.box_loss2(src_boxes, target_boxes)
|
135 |
+
|
136 |
+
loss_bbox = self.masked_loss(loss_bbox, 0)
|
137 |
+
box_losses.append(loss_bbox)
|
138 |
+
sam_losses.append( F.l1_loss(embed_sam_boxes, gt_embeds))
|
139 |
+
|
140 |
+
logger.info(f'refering sam loss:{sam_losses}')
|
141 |
+
sam_losses = torch.stack(sam_losses)
|
142 |
+
box_losses = torch.stack(box_losses)
|
143 |
+
loss += torch.mean(sam_losses)
|
144 |
+
loss += torch.mean(box_losses)
|
145 |
+
|
146 |
+
if tvg_inputs is not None and self.use_temporal_loss:
|
147 |
+
last_hidden_states = outputs.hidden_states[-1] # [bsz,1024, 4096]
|
148 |
+
last_hidden_states = last_hidden_states.view(-1, last_hidden_states.size(-1)) # [bsz*1024, 4096]
|
149 |
+
loc_positions = (input_ids.flatten()==self.tokenizer.temp_token).nonzero().flatten() # [bsz]
|
150 |
+
prompt_token = last_hidden_states[loc_positions]
|
151 |
+
prompt_token = prompt_token.view(input_ids.shape[0], -1 ,prompt_token.shape[-1]) # [bsz, 1, 4096]
|
152 |
+
|
153 |
+
|
154 |
+
cg_outputs = self.cg_model(**tvg_inputs, targets=tvg_targets, prompt_token=prompt_token)
|
155 |
+
loss_dict = self.cg_criterion(cg_outputs, tvg_targets)
|
156 |
+
weight_dict = self.cg_criterion.weight_dict
|
157 |
+
tvg_loss = 0.05*sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
|
158 |
+
logger.info(f'tvg_loss:{tvg_loss}')
|
159 |
+
loss += tvg_loss
|
160 |
+
|
161 |
+
|
162 |
+
logger.info(f'all loss:{loss}')
|
163 |
+
return CausalLMOutputWithPast(
|
164 |
+
loss=loss,
|
165 |
+
logits=outputs.logits,
|
166 |
+
past_key_values=outputs.past_key_values,
|
167 |
+
hidden_states=outputs.hidden_states,
|
168 |
+
attentions=outputs.attentions,
|
169 |
+
)
|
170 |
+
|
171 |
+
def pad_text_embeds(
|
172 |
+
self,
|
173 |
+
input_ids: torch.LongTensor = None,
|
174 |
+
image: Optional[torch.Tensor] = None,
|
175 |
+
video: Optional[torch.Tensor] = None,
|
176 |
+
image_idx = None,
|
177 |
+
video_idx = None,
|
178 |
+
return_visual: bool = False,
|
179 |
+
instruction = None,
|
180 |
+
output_boxes = None, # boxes for REC
|
181 |
+
input_boxes = None, # boxes for tracking
|
182 |
+
time_instructions = None,
|
183 |
+
):
|
184 |
+
text_embeds = self.lm.get_input_embeddings()(input_ids.long()).detach()
|
185 |
+
if input_boxes is not None:
|
186 |
+
input_boxes = input_boxes[0].to(dtype=text_embeds.dtype)
|
187 |
+
|
188 |
+
boxes_emb = self.loc_encoder(input_boxes)
|
189 |
+
boxes_emb = boxes_emb.view(-1, 4096)
|
190 |
+
|
191 |
+
text_embeds[input_ids == torch.full_like(input_ids, self.tokenizer.track_box_token)] = text_embeds[input_ids == torch.full_like(input_ids, self.tokenizer.track_box_token)] * 0 + boxes_emb.to(text_embeds.device)
|
192 |
+
logger.info(f'embedings:{text_embeds[input_ids == torch.full_like(input_ids, self.tokenizer.track_box_token)].shape}')
|
193 |
+
visual = None
|
194 |
+
visual_idx = None
|
195 |
+
|
196 |
+
if image is not None:
|
197 |
+
|
198 |
+
B, T, C, H, W = image.shape
|
199 |
+
image = image.permute(0, 2, 1, 3, 4)
|
200 |
+
|
201 |
+
instruction = None
|
202 |
+
|
203 |
+
prompt_image_embeds = self.encode_vision(image, instruction)
|
204 |
+
|
205 |
+
visual = prompt_image_embeds
|
206 |
+
|
207 |
+
prompt_image_embeds = self.project_up(prompt_image_embeds) # 768 -> 4096
|
208 |
+
prompt_image_embeds = prompt_image_embeds.view(-1, prompt_image_embeds.shape[-1])
|
209 |
+
|
210 |
+
visual_idx = image_idx
|
211 |
+
|
212 |
+
prompt_image_embeds = prompt_image_embeds.to(dtype=text_embeds.dtype)
|
213 |
+
|
214 |
+
text_embeds[image_idx == 1] = torch.zeros_like(text_embeds[image_idx == 1]) + prompt_image_embeds.to(text_embeds.device)
|
215 |
+
|
216 |
+
|
217 |
+
elif video is not None:
|
218 |
+
if len(video.shape) == 5:
|
219 |
+
B, T, C, H, W = video.shape
|
220 |
+
N = 1
|
221 |
+
if self.use_clip:
|
222 |
+
video = video.reshape(B*self.num_clips, T//self.num_clips, C, H, W) # [16, 8, 3, 224, 224]
|
223 |
+
else:
|
224 |
+
B, N, T, C, H, W = video.shape
|
225 |
+
|
226 |
+
video = video.permute(0,2,1,3,4) #
|
227 |
+
|
228 |
+
|
229 |
+
prompt_video_embeds = self.encode_vision(video, instruction=time_instructions) # [2, 96, 768]
|
230 |
+
if self.use_clip:
|
231 |
+
prompt_video_embeds = prompt_video_embeds.reshape(B,-1,prompt_video_embeds.shape[-1]) # [2,8*96,768]
|
232 |
+
batch_size, img_len, token_dim = prompt_video_embeds.shape
|
233 |
+
prompt_video_embeds = prompt_video_embeds.view(batch_size, img_len // self.token_merge_len, self.token_merge_len * token_dim) # [B, 768//4, 4*768] = [2, 192, 3072]
|
234 |
+
prompt_video_embeds = self.merge_proj(prompt_video_embeds) # [2, 192, 4096]
|
235 |
+
prompt_video_embeds = prompt_video_embeds.view(-1, prompt_video_embeds.shape[-1]) # [2*192, 4096]
|
236 |
+
|
237 |
+
else:
|
238 |
+
prompt_video_embeds = self.project_up(prompt_video_embeds) # [2, 96, 4096]
|
239 |
+
|
240 |
+
prompt_video_embeds = prompt_video_embeds.view(-1, prompt_video_embeds.shape[-1])
|
241 |
+
visual_idx = video_idx
|
242 |
+
|
243 |
+
|
244 |
+
text_embeds[video_idx == 1] = torch.zeros_like(text_embeds[video_idx == 1]) + prompt_video_embeds.to(text_embeds.device).to(text_embeds.dtype)
|
245 |
+
|
246 |
+
else:
|
247 |
+
logger.warn(f"don't get visual input, input_ids: {input_ids}")
|
248 |
+
|
249 |
+
|
250 |
+
for idx, text_embed in enumerate(text_embeds):
|
251 |
+
if text_embeds[idx][input_ids[idx].flatten() == self.tokenizer.box_token].shape[0] != 0:
|
252 |
+
text_embeds[idx][input_ids[idx].flatten() == self.tokenizer.box_token] = torch.zeros_like(text_embeds[idx][input_ids[idx] == self.tokenizer.box_token]) + torch.cat([self.box_token.squeeze(0)] * (text_embeds[idx][input_ids[idx] == self.tokenizer.box_token]).shape[0]).to(text_embeds.dtype)
|
253 |
+
if text_embeds[idx][input_ids[idx].flatten() == self.tokenizer.temp_token].shape[0] != 0:
|
254 |
+
text_embeds[idx][input_ids[idx].flatten() == self.tokenizer.temp_token] = torch.zeros_like(text_embeds[idx][input_ids[idx] == self.tokenizer.temp_token]) + self.temporal_token
|
255 |
+
if text_embeds[idx][input_ids[idx].flatten() == self.tokenizer.track_token].shape[0] != 0:
|
256 |
+
text_embeds[idx][input_ids[idx].flatten() == self.tokenizer.track_token] = torch.zeros_like(text_embeds[idx][input_ids[idx] == self.tokenizer.track_token]) + self.track_token
|
257 |
+
|
258 |
+
if return_visual:
|
259 |
+
return text_embeds, visual, visual_idx
|
260 |
+
|
261 |
+
return text_embeds
|
262 |
+
|
263 |
+
|
264 |
+
|
265 |
+
def temporal_decode(self, temporal_embedding):
|
266 |
+
pred_sted = self.temporal_embed(temporal_embedding)
|
267 |
+
pred_actioness = self.action_embed(temporal_embedding)
|
268 |
+
return pred_sted, pred_actioness
|
269 |
+
|
270 |
+
|
271 |
+
def box_loss2(self, src_boxes, target_boxes):
|
272 |
+
src_boxes = torch.cat(src_boxes)
|
273 |
+
target_boxes = torch.cat(target_boxes)
|
274 |
+
|
275 |
+
loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
|
276 |
+
loss_bbox = self.masked_loss(loss_bbox, 0)
|
277 |
+
mask = (src_boxes[2:] >= src_boxes[:2]).all(-1)
|
278 |
+
src_boxes = src_boxes[mask]
|
279 |
+
target_boxes = target_boxes[mask]
|
280 |
+
|
281 |
+
return loss_bbox
|
282 |
+
|
283 |
+
def box_loss(self, src_boxes, target_boxes):
|
284 |
+
src_boxes = torch.cat(src_boxes)
|
285 |
+
target_boxes = torch.cat(target_boxes)
|
286 |
+
|
287 |
+
loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
|
288 |
+
loss_bbox = self.masked_loss(loss_bbox, 0)
|
289 |
+
mask = (src_boxes[:, 2:] >= src_boxes[ :, :2]).all(-1)
|
290 |
+
src_boxes = src_boxes[mask]
|
291 |
+
target_boxes = target_boxes[mask]
|
292 |
+
|
293 |
+
if src_boxes.shape[0] > 0:
|
294 |
+
loss_giou = 1 - torch.diag(generalized_box_iou(
|
295 |
+
src_boxes,
|
296 |
+
target_boxes))
|
297 |
+
loss_giou = self.masked_loss(loss_giou, 0)
|
298 |
+
else:
|
299 |
+
loss_giou = torch.tensor(2, dtype=src_boxes.dtype)
|
300 |
+
iou, union = box_iou(src_boxes, target_boxes)
|
301 |
+
|
302 |
+
return loss_bbox * 2 + loss_giou / 5
|
303 |
+
|
304 |
+
def find_boundaries_torch(self, mask):
|
305 |
+
|
306 |
+
from skimage.segmentation import find_boundaries
|
307 |
+
mask_np = mask.to(torch.bool).numpy()
|
308 |
+
boundaries = find_boundaries(mask_np, mode='outer')
|
309 |
+
boundary_points = np.argwhere(boundaries)
|
310 |
+
if boundary_points.size == 0:
|
311 |
+
return torch.tensor([-1, -1, -1, -1], dtype = torch.bfloat16)
|
312 |
+
h0, w0 = boundary_points.min(axis=0)
|
313 |
+
h1, w1 = boundary_points.max(axis=0)
|
314 |
+
return torch.tensor([w0 / mask.shape[1], h0 / mask.shape[0], w1 / mask.shape[1], h1 / mask.shape[0]], dtype = torch.bfloat16)
|
315 |
+
|
316 |
+
|
317 |
+
def sam_loss(self, sam_outputs, gt_masks):
|
318 |
+
bound1 = self.find_boundaries_torch(gt_masks[:,:,:1].squeeze(2).cpu())
|
319 |
+
bound2 = self.find_boundaries_torch(sam_outputs[:,:,:1].squeeze(2).cpu())
|
320 |
+
|
321 |
+
lossl1 = F.l1_loss(bound1, bound2, reduction='none')
|
322 |
+
lossl1 = self.masked_loss(lossl1, 0)
|
323 |
+
|
324 |
+
loss_iou = self.iou_loss(sam_outputs, gt_masks)
|
325 |
+
loss_dice = self.dice_loss(sam_outputs, gt_masks)
|
326 |
+
|
327 |
+
# print(f'mask loss:{loss_iou, loss_dice}')
|
328 |
+
return loss_iou + loss_dice + lossl1
|
329 |
+
|
330 |
+
def masked_loss(self, loss, n):
|
331 |
+
mask = torch.ones_like(loss)
|
332 |
+
# mask[-n:] = 1e-10
|
333 |
+
loss = (loss*mask).sum()/(mask.sum())
|
334 |
+
return loss
|
335 |
+
|
336 |
+
def encode_vision(
|
337 |
+
self,
|
338 |
+
image,
|
339 |
+
instruction
|
340 |
+
):
|
341 |
+
device = image.device
|
342 |
+
B = image.shape[0]
|
343 |
+
T = image.shape[2]
|
344 |
+
use_image = True if T == 1 else False
|
345 |
+
image_embeds = self.vision_encoder(image, use_image=use_image)
|
346 |
+
C = image_embeds.shape[-1]
|
347 |
+
image_embeds = image_embeds.reshape(B, -1, C)
|
348 |
+
image_embeds = self.vision_layernorm(image_embeds).to(device) # [B, T*L, C]
|
349 |
+
|
350 |
+
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
|
351 |
+
if self.extra_num_query_token > 0:
|
352 |
+
query_tokens = torch.cat([self.query_tokens, self.extra_query_tokens], dim=1)
|
353 |
+
query_tokens = query_tokens.expand(image_embeds.shape[0], -1, -1)
|
354 |
+
if instruction is not None:
|
355 |
+
text_Qformer = self.qformer_tokenizer(
|
356 |
+
instruction,
|
357 |
+
padding='longest',
|
358 |
+
truncation=True,
|
359 |
+
max_length=512,
|
360 |
+
return_tensors="pt",
|
361 |
+
).to(image_embeds.device)
|
362 |
+
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image_embeds.device)
|
363 |
+
Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
|
364 |
+
query_output = self.qformer.bert(
|
365 |
+
text_Qformer.input_ids,
|
366 |
+
attention_mask=Qformer_atts,
|
367 |
+
query_embeds=query_tokens,
|
368 |
+
encoder_hidden_states=image_embeds,
|
369 |
+
encoder_attention_mask=image_atts,
|
370 |
+
return_dict=True,
|
371 |
+
)
|
372 |
+
else:
|
373 |
+
query_output = self.qformer.bert(
|
374 |
+
query_embeds=query_tokens,
|
375 |
+
encoder_hidden_states=image_embeds,
|
376 |
+
encoder_attention_mask=image_atts,
|
377 |
+
return_dict=True,
|
378 |
+
)
|
379 |
+
|
380 |
+
return query_output.last_hidden_state[:, :query_tokens.size(1), :]
|
381 |
+
|
382 |
+
def generate_caption(
|
383 |
+
self,
|
384 |
+
input_ids,
|
385 |
+
attention_mask,
|
386 |
+
image_idx = None,
|
387 |
+
video_idx = None,
|
388 |
+
image: Optional[torch.Tensor] = None,
|
389 |
+
video: Optional[torch.Tensor] = None,
|
390 |
+
num_beams=1,
|
391 |
+
max_new_tokens=200,
|
392 |
+
do_sample=True,
|
393 |
+
top_p=0.9,
|
394 |
+
top_k=None,
|
395 |
+
temperature=1.0,
|
396 |
+
length_penalty=1,
|
397 |
+
repetition_penalty=1.0,
|
398 |
+
):
|
399 |
+
text_embeds = self.pad_text_embeds(input_ids=input_ids, image=image, video=video, image_idx=image_idx, video_idx=video_idx)
|
400 |
+
outputs = self.lm.generate(
|
401 |
+
inputs_embeds=text_embeds,
|
402 |
+
attention_mask=attention_mask,
|
403 |
+
num_beams=num_beams,
|
404 |
+
max_new_tokens=max_new_tokens,
|
405 |
+
do_sample=do_sample,
|
406 |
+
min_length=1,
|
407 |
+
top_p=top_p,
|
408 |
+
top_k=top_k,
|
409 |
+
temperature=temperature,
|
410 |
+
length_penalty=length_penalty,
|
411 |
+
repetition_penalty=repetition_penalty,
|
412 |
+
)
|
413 |
+
|
414 |
+
return outputs
|
415 |
+
|
416 |
+
def generate_caption_bbox(
|
417 |
+
self,
|
418 |
+
input_ids,
|
419 |
+
attention_mask,
|
420 |
+
labels,
|
421 |
+
image_idx = None,
|
422 |
+
video_idx = None,
|
423 |
+
image: Optional[torch.Tensor] = None,
|
424 |
+
video: Optional[torch.Tensor] = None,
|
425 |
+
num_beams=1,
|
426 |
+
max_new_tokens=200,
|
427 |
+
do_sample=True,
|
428 |
+
top_p=0.9,
|
429 |
+
top_k=None,
|
430 |
+
temperature=0.9,
|
431 |
+
length_penalty=1,
|
432 |
+
repetition_penalty=1.0,
|
433 |
+
):
|
434 |
+
text_embeds = self.pad_text_embeds(input_ids=input_ids, image=image, video=video, image_idx=image_idx, video_idx=video_idx)
|
435 |
+
outputs = self.lm.generate(
|
436 |
+
inputs_embeds=text_embeds,
|
437 |
+
attention_mask=attention_mask,
|
438 |
+
num_beams=num_beams,
|
439 |
+
max_new_tokens=max_new_tokens,
|
440 |
+
do_sample=do_sample,
|
441 |
+
min_length=1,
|
442 |
+
top_p=top_p,
|
443 |
+
top_k=top_k,
|
444 |
+
temperature=temperature,
|
445 |
+
length_penalty=length_penalty,
|
446 |
+
repetition_penalty=repetition_penalty,
|
447 |
+
)
|
448 |
+
decoded_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
449 |
+
# torch.save({'text':decoded_text, 'output':{outputs}}, 'tmp.pth')
|
450 |
+
# print(decoded_text)
|
451 |
+
return outputs
|
452 |
+
|
453 |
+
def generate_temporal(self,
|
454 |
+
input_ids: torch.LongTensor = None,
|
455 |
+
attention_mask: Optional[torch.Tensor] = None,
|
456 |
+
labels: Optional[torch.LongTensor] = None,
|
457 |
+
image: Optional[torch.Tensor] = None,
|
458 |
+
video: Optional[torch.Tensor] = None,
|
459 |
+
instruction = None,
|
460 |
+
video_idx = None,
|
461 |
+
image_idx = None,
|
462 |
+
boxes = None,
|
463 |
+
text_input = None,
|
464 |
+
video_info = None,
|
465 |
+
temporal_labels = None):
|
466 |
+
|
467 |
+
if text_input is not None:
|
468 |
+
time_instructions = self.get_clip_time_instruct(text_input)
|
469 |
+
else:
|
470 |
+
time_instructions = None
|
471 |
+
text_embeds = self.pad_text_embeds(input_ids=input_ids, image=image, video=video, return_visual=False,
|
472 |
+
video_idx=video_idx, image_idx=image_idx, instruction = instruction,
|
473 |
+
boxes = boxes, time_instructions = time_instructions)
|
474 |
+
|
475 |
+
# TODO
|
476 |
+
outputs = self.lm(
|
477 |
+
inputs_embeds=text_embeds,
|
478 |
+
attention_mask=attention_mask,
|
479 |
+
labels=labels,
|
480 |
+
output_hidden_states=True,
|
481 |
+
return_dict=True,
|
482 |
+
)
|
483 |
+
|
484 |
+
if temporal_labels is not None:
|
485 |
+
start_sec = temporal_labels["start_sec"]
|
486 |
+
end_sec = temporal_labels["end_sec"]
|
487 |
+
fps = video_info['fps']
|
488 |
+
frame_indices = video_info['frame_indices']
|
489 |
+
|
490 |
+
last_hidden_states = outputs.hidden_states[-1] # [2,1024, 4096]
|
491 |
+
last_hidden_states = last_hidden_states.view(-1, last_hidden_states.size(-1)) # [2048, 4096]
|
492 |
+
loc_positions = (input_ids.flatten()==self.tokenizer.temp_place_ids).nonzero().flatten() #
|
493 |
+
selected_hidden_states = last_hidden_states[loc_positions]
|
494 |
+
selected_hidden_states = selected_hidden_states.view(input_ids.shape[0], -1 ,selected_hidden_states.shape[-1]) # [2, 64, 4096]
|
495 |
+
|
496 |
+
# just for debug
|
497 |
+
|
498 |
+
# vis_embed = vis_embed[:,:64,:]
|
499 |
+
|
500 |
+
pred_sted, pred_actionness = self.temporal_decode(selected_hidden_states) # [2,64,2] [2,64,1]
|
501 |
+
|
502 |
+
pred_sted = self.postprocess(pred_sted, frame_indices)
|
503 |
+
pred_sec_s = pred_sted[0][0] / fps[0][0].item()
|
504 |
+
pred_sec_e = pred_sted[0][1] / fps[0][0].item()
|
505 |
+
|
506 |
+
output_file = "predictions2.jsonl"
|
507 |
+
prediction = {"pred_sec_s": round(pred_sec_s, 1), "pred_sec_e": round(pred_sec_e, 1), "start_sec":float(start_sec[0]), "end_sec": float(end_sec[0])}
|
508 |
+
|
509 |
+
with open(output_file, 'a') as f:
|
510 |
+
json.dump(prediction, f)
|
511 |
+
f.write('\n')
|
512 |
+
|
513 |
+
return outputs
|
514 |
+
|
515 |
+
def generate_seg(self, input_ids, attention_mask, labels, image, image_idx, video, video_idx, input_boxes, size_hw, sam_images):
|
516 |
+
device = input_ids.device
|
517 |
+
prompt = input_ids
|
518 |
+
l_prompt = len(input_ids)
|
519 |
+
temperature = 1e-5
|
520 |
+
max_new_tokens = 20
|
521 |
+
guide_w = 5
|
522 |
+
stop_str = '</s>'
|
523 |
+
bbox = []
|
524 |
+
output_ids = list(input_ids[0])
|
525 |
+
text_embeds = self.pad_text_embeds(input_ids=input_ids, image=image, video=video, image_idx=image_idx, video_idx=video_idx, return_visual=False,
|
526 |
+
instruction = None, output_boxes=None, input_boxes=input_boxes)
|
527 |
+
for i in range(max_new_tokens):
|
528 |
+
if i == 0:
|
529 |
+
outputs = self.lm(
|
530 |
+
inputs_embeds=text_embeds,
|
531 |
+
attention_mask=attention_mask,
|
532 |
+
output_hidden_states=True,
|
533 |
+
return_dict=True,
|
534 |
+
)
|
535 |
+
logits = outputs.logits
|
536 |
+
past_key_values = outputs.past_key_values
|
537 |
+
else:
|
538 |
+
attention_mask = torch.ones(1, past_key_values[0][0].shape[-2] + 1, device=device)
|
539 |
+
last_text_embeds = self.lm.get_input_embeddings()(torch.tensor(output_ids[-1], device=device).long()).detach().unsqueeze(0)
|
540 |
+
last_text_embeds = last_text_embeds.unsqueeze(0)
|
541 |
+
|
542 |
+
out = self.lm(
|
543 |
+
input_ids=None,
|
544 |
+
use_cache=True,
|
545 |
+
attention_mask=attention_mask,
|
546 |
+
output_hidden_states=True,
|
547 |
+
inputs_embeds=last_text_embeds,
|
548 |
+
past_key_values=past_key_values,
|
549 |
+
)
|
550 |
+
logits = out.logits
|
551 |
+
past_key_values = out.past_key_values
|
552 |
+
if logits is not None:
|
553 |
+
last_token_logits = logits[0][-1]
|
554 |
+
if temperature < 1e-4:
|
555 |
+
token = int(torch.argmax(last_token_logits))
|
556 |
+
else:
|
557 |
+
probs = torch.softmax(last_token_logits / temperature, dim=-1)
|
558 |
+
token = int(torch.multinomial(probs, num_samples=1))
|
559 |
+
output_ids.append(token)
|
560 |
+
ret = self.tokenizer.decode(token)
|
561 |
+
if ret == '<box_begin>':
|
562 |
+
attention_mask = torch.ones(1, past_key_values[0][0].shape[-2] + 1, device=device)
|
563 |
+
bbox_embeds = self.box_token.bfloat16()
|
564 |
+
out = self.lm(
|
565 |
+
inputs_embeds=bbox_embeds,
|
566 |
+
use_cache=True,
|
567 |
+
attention_mask=attention_mask,
|
568 |
+
output_hidden_states=True,
|
569 |
+
past_key_values=past_key_values
|
570 |
+
)
|
571 |
+
last_hidden_states = out.hidden_states[-1]
|
572 |
+
selected_hidden_states = last_hidden_states[0][0]
|
573 |
+
bbox.append(self.loc_decoder(selected_hidden_states))
|
574 |
+
last_token_logits = logits[0][-1]
|
575 |
+
if temperature < 1e-4:
|
576 |
+
token = int(torch.argmax(last_token_logits))
|
577 |
+
else:
|
578 |
+
probs = torch.softmax(last_token_logits / temperature, dim=-1)
|
579 |
+
token = int(torch.multinomial(probs, num_samples=1))
|
580 |
+
if ret == '<track_begin>':
|
581 |
+
attention_mask = torch.ones(1, past_key_values[0][0].shape[-2] + 1, device=device)
|
582 |
+
tracking_embeds = self.track_token
|
583 |
+
out = self.lm(
|
584 |
+
inputs_embeds=tracking_embeds,
|
585 |
+
use_cache=True,
|
586 |
+
attention_mask=attention_mask,
|
587 |
+
output_hidden_states=True,
|
588 |
+
past_key_values=past_key_values
|
589 |
+
)
|
590 |
+
last_hidden_states = out.hidden_states[-1]
|
591 |
+
selected_hidden_states = last_hidden_states[0][0].to(dtype = torch.bfloat16)
|
592 |
+
|
593 |
+
embed_sam_boxes = self.track_embed(selected_hidden_states).reshape(1, 3, 256)
|
594 |
+
|
595 |
+
inference_state = self.sam.init_state_images(sam_images, size_hw[0][0], size_hw[0][1])
|
596 |
+
gt_embeds = self.sam.get_prompt_embeding(inference_state, None, None, False, input_boxes[0].cuda(), device = text_embeds.device)
|
597 |
+
ann_frame_idx = 0
|
598 |
+
ann_obj_id = 0
|
599 |
+
box = np.array([0, 0, 0, 0], dtype=np.float32)
|
600 |
+
_, out_obj_ids, out_mask_logits = self.sam.add_new_box_embeding(
|
601 |
+
inference_state=inference_state,
|
602 |
+
frame_idx=ann_frame_idx,
|
603 |
+
obj_id=ann_obj_id,
|
604 |
+
box=box,
|
605 |
+
box_embeding=embed_sam_boxes,
|
606 |
+
)
|
607 |
+
video_segments = {} # video_segments contains the per-frame segmentation results
|
608 |
+
for out_frame_idx, out_obj_ids, out_mask_logits in self.sam.propagate_in_video(inference_state):
|
609 |
+
video_segments[out_frame_idx] = {
|
610 |
+
out_obj_id: (out_mask_logits[i] > 0.0)
|
611 |
+
for i, out_obj_id in enumerate(out_obj_ids)
|
612 |
+
}
|
613 |
+
video_segments = [video_segments[tt][0] for tt in video_segments]
|
614 |
+
# bbox = model.find_boundaries_torch(video_segments[0].squeeze(0).cpu())
|
615 |
+
# return ret, [], video_segments
|
616 |
+
|
617 |
+
if (ret == '</s>'):
|
618 |
+
break
|
619 |
+
ret = self.tokenizer.decode(output_ids)
|
620 |
+
del past_key_values
|
621 |
+
return ret, bbox, video_segments
|
622 |
+
|
623 |
+
def generate_answer(self, tokenizer, instruction, msg, user_prompt, media_type="video",video_tensor=None, image_tensor=None, answer_prompt=None, chat_history=[],return_history=False, debug=False, generation_config={}):
|
624 |
+
input_ids, attention_masks, labels = [], [], []
|
625 |
+
|
626 |
+
conversation = ""
|
627 |
+
if instruction:
|
628 |
+
conversation += instruction
|
629 |
+
conversation += (
|
630 |
+
"[INST]" + " "
|
631 |
+
)
|
632 |
+
|
633 |
+
if media_type == 'image':
|
634 |
+
conversation +=( "<Image>" + IMG_TOKEN + "</Image>")
|
635 |
+
else:
|
636 |
+
conversation += ("<Video>" + VID_TOKEN + "</Video>")
|
637 |
+
|
638 |
+
conversation += ( msg.rstrip() + "[/INST]")
|
639 |
+
|
640 |
+
for q,a in chat_history:
|
641 |
+
conversation += (" [INST] " + q + " [/INST]")
|
642 |
+
conversation += (a + "</s>")
|
643 |
+
|
644 |
+
conversation += (" [INST] " + user_prompt + " [/INST]")
|
645 |
+
conversation += ("")
|
646 |
+
if answer_prompt:
|
647 |
+
conversation += ("Best Option: (")
|
648 |
+
total_len = 0
|
649 |
+
indexs = []
|
650 |
+
if debug:
|
651 |
+
print(conversation)
|
652 |
+
|
653 |
+
tokenized = tokenizer.build_input_ids([conversation],
|
654 |
+
max_length=1024,
|
655 |
+
add_special_tokens=True,
|
656 |
+
truncation=False,
|
657 |
+
padding=False,
|
658 |
+
return_tensors='pt',
|
659 |
+
image=image_tensor,
|
660 |
+
video=video_tensor,
|
661 |
+
require_video=True)
|
662 |
+
if video_tensor is not None:
|
663 |
+
generation_output = self.generate_caption(
|
664 |
+
tokenized['input_ids'].unsqueeze(0).to(self.device),
|
665 |
+
tokenized['attention_mask'].unsqueeze(0).to(self.device),
|
666 |
+
video_idx = tokenized['video_index'].unsqueeze(0),
|
667 |
+
video = video_tensor.unsqueeze(0).to(self.device,dtype=torch.bfloat16),
|
668 |
+
do_sample=False
|
669 |
+
)
|
670 |
+
elif image_tensor is not None:
|
671 |
+
generation_output = self.generate_caption(
|
672 |
+
tokenized['input_ids'].unsqueeze(0).to(self.device),
|
673 |
+
tokenized['attention_mask'].unsqueeze(0).to(self.device),
|
674 |
+
image_idx = tokenized['image_index'].unsqueeze(0),
|
675 |
+
image = image_tensor.unsqueeze(0).to(self.device,dtype=torch.bfloat16),
|
676 |
+
do_sample=False
|
677 |
+
)
|
678 |
+
response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
|
679 |
+
if debug:
|
680 |
+
print(response)
|
681 |
+
return response, chat_history
|
modeling_vit.py
ADDED
@@ -0,0 +1,487 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import torch.utils.checkpoint as checkpoint
|
7 |
+
from functools import partial
|
8 |
+
|
9 |
+
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
10 |
+
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
|
14 |
+
def _cfg(url='', **kwargs):
|
15 |
+
return {
|
16 |
+
'url': url,
|
17 |
+
'num_classes': 400, 'input_size': (3, 224, 224), 'pool_size': None,
|
18 |
+
'crop_pct': .9, 'interpolation': 'bicubic',
|
19 |
+
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
|
20 |
+
**kwargs
|
21 |
+
}
|
22 |
+
|
23 |
+
class MLP(nn.Module):
|
24 |
+
"""Very simple multi-layer perceptron (also called FFN)"""
|
25 |
+
|
26 |
+
def __init__(self, input_dim, hidden_dim, output_dim, num_layers, dropout=0):
|
27 |
+
super().__init__()
|
28 |
+
self.num_layers = num_layers
|
29 |
+
h = [hidden_dim] * (num_layers - 1)
|
30 |
+
self.layers = nn.ModuleList(
|
31 |
+
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
|
32 |
+
)
|
33 |
+
self.dropout = dropout
|
34 |
+
if dropout:
|
35 |
+
self.dropout = nn.Dropout(dropout)
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
for i, layer in enumerate(self.layers):
|
39 |
+
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
40 |
+
if self.dropout and i < self.num_layers:
|
41 |
+
x = self.dropout(x)
|
42 |
+
return x
|
43 |
+
|
44 |
+
class PostProcess(nn.Module):
|
45 |
+
""" This module converts the model's output into the format expected by the coco api"""
|
46 |
+
|
47 |
+
@torch.no_grad()
|
48 |
+
def forward(self, out_sted, frames_id):
|
49 |
+
"""Perform the computation for inference evaluation
|
50 |
+
"""
|
51 |
+
# import pdb; pdb.set_trace()
|
52 |
+
|
53 |
+
b, t, _ = out_sted.shape
|
54 |
+
device = out_sted.device
|
55 |
+
temp_prob_map = torch.zeros(b,t,t).to(device)
|
56 |
+
inf = -1e32
|
57 |
+
for i_b in range(len(frames_id)):
|
58 |
+
duration = len(frames_id[0])
|
59 |
+
sted_prob = (torch.ones(t, t) * inf).tril(0).to(device)
|
60 |
+
sted_prob[duration:,:] = inf
|
61 |
+
sted_prob[:,duration:] = inf
|
62 |
+
temp_prob_map[i_b,:,:] = sted_prob
|
63 |
+
|
64 |
+
temp_prob_map += F.log_softmax(out_sted[:, :, 0], dim=1).unsqueeze(2) + \
|
65 |
+
F.log_softmax(out_sted[:, :, 1], dim=1).unsqueeze(1)
|
66 |
+
|
67 |
+
pred_steds = []
|
68 |
+
for i_b in range(b):
|
69 |
+
prob_map = temp_prob_map[i_b] # [T * T]
|
70 |
+
frame_id_seq = frames_id[i_b]
|
71 |
+
prob_seq = prob_map.flatten(0)
|
72 |
+
max_tstamp = prob_seq.max(dim=0)[1].item()
|
73 |
+
start_idx = max_tstamp // t
|
74 |
+
end_idx = max_tstamp % t
|
75 |
+
pred_sted = [frame_id_seq[start_idx], frame_id_seq[end_idx]+1]
|
76 |
+
pred_steds.append(pred_sted)
|
77 |
+
|
78 |
+
return pred_steds
|
79 |
+
|
80 |
+
class DropPath(nn.Module):
|
81 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
82 |
+
"""
|
83 |
+
def __init__(self, drop_prob=None):
|
84 |
+
super(DropPath, self).__init__()
|
85 |
+
self.drop_prob = drop_prob
|
86 |
+
|
87 |
+
def forward(self, x):
|
88 |
+
return drop_path(x, self.drop_prob, self.training)
|
89 |
+
|
90 |
+
def extra_repr(self) -> str:
|
91 |
+
return 'p={}'.format(self.drop_prob)
|
92 |
+
|
93 |
+
|
94 |
+
class Mlp(nn.Module):
|
95 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
96 |
+
super().__init__()
|
97 |
+
out_features = out_features or in_features
|
98 |
+
hidden_features = hidden_features or in_features
|
99 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
100 |
+
self.act = act_layer()
|
101 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
102 |
+
self.drop = nn.Dropout(drop)
|
103 |
+
|
104 |
+
def forward(self, x):
|
105 |
+
x = self.fc1(x)
|
106 |
+
x = self.act(x)
|
107 |
+
x = self.drop(x)
|
108 |
+
x = self.fc2(x)
|
109 |
+
x = self.drop(x)
|
110 |
+
return x
|
111 |
+
|
112 |
+
|
113 |
+
class Attention(nn.Module):
|
114 |
+
def __init__(
|
115 |
+
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
|
116 |
+
proj_drop=0., attn_head_dim=None):
|
117 |
+
super().__init__()
|
118 |
+
self.num_heads = num_heads
|
119 |
+
head_dim = dim // num_heads
|
120 |
+
if attn_head_dim is not None:
|
121 |
+
head_dim = attn_head_dim
|
122 |
+
all_head_dim = head_dim * self.num_heads
|
123 |
+
self.scale = qk_scale or head_dim ** -0.5
|
124 |
+
|
125 |
+
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
126 |
+
if qkv_bias:
|
127 |
+
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
128 |
+
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
129 |
+
else:
|
130 |
+
self.q_bias = None
|
131 |
+
self.v_bias = None
|
132 |
+
|
133 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
134 |
+
self.proj = nn.Linear(all_head_dim, dim)
|
135 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
136 |
+
|
137 |
+
def forward(self, x):
|
138 |
+
B, N, C = x.shape
|
139 |
+
qkv_bias = None
|
140 |
+
if self.q_bias is not None:
|
141 |
+
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
|
142 |
+
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
143 |
+
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
144 |
+
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
145 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
146 |
+
|
147 |
+
q = q * self.scale
|
148 |
+
attn = (q @ k.transpose(-2, -1))
|
149 |
+
|
150 |
+
attn = attn.softmax(dim=-1)
|
151 |
+
attn = self.attn_drop(attn)
|
152 |
+
|
153 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
154 |
+
x = self.proj(x)
|
155 |
+
x = self.proj_drop(x)
|
156 |
+
return x
|
157 |
+
|
158 |
+
|
159 |
+
class Block(nn.Module):
|
160 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
161 |
+
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
|
162 |
+
attn_head_dim=None):
|
163 |
+
super().__init__()
|
164 |
+
self.norm1 = norm_layer(dim)
|
165 |
+
self.attn = Attention(
|
166 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
167 |
+
attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim)
|
168 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
169 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
170 |
+
self.norm2 = norm_layer(dim)
|
171 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
172 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
173 |
+
|
174 |
+
if init_values > 0:
|
175 |
+
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
|
176 |
+
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
|
177 |
+
else:
|
178 |
+
self.gamma_1, self.gamma_2 = None, None
|
179 |
+
|
180 |
+
def forward(self, x):
|
181 |
+
if self.gamma_1 is None:
|
182 |
+
x = x + self.drop_path(self.attn(self.norm1(x)))
|
183 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
184 |
+
else:
|
185 |
+
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
|
186 |
+
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
187 |
+
return x
|
188 |
+
|
189 |
+
|
190 |
+
class PatchEmbed(nn.Module):
|
191 |
+
""" Image to Patch Embedding
|
192 |
+
"""
|
193 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2):
|
194 |
+
super().__init__()
|
195 |
+
img_size = to_2tuple(img_size)
|
196 |
+
patch_size = to_2tuple(patch_size)
|
197 |
+
self.tubelet_size = int(tubelet_size)
|
198 |
+
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (num_frames // self.tubelet_size)
|
199 |
+
self.img_size = img_size
|
200 |
+
self.patch_size = patch_size
|
201 |
+
self.num_patches = num_patches
|
202 |
+
self.proj = nn.Conv3d(
|
203 |
+
in_channels=in_chans, out_channels=embed_dim,
|
204 |
+
kernel_size=(self.tubelet_size, patch_size[0], patch_size[1]),
|
205 |
+
stride=(self.tubelet_size, patch_size[0], patch_size[1])
|
206 |
+
)
|
207 |
+
logger.info(f'Num of patches: {num_patches}')
|
208 |
+
|
209 |
+
def forward(self, x, **kwargs):
|
210 |
+
B, C, T, H, W = x.shape
|
211 |
+
# FIXME look at relaxing size constraints
|
212 |
+
# assert H == self.img_size[0] and W == self.img_size[1], \
|
213 |
+
# f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
214 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
215 |
+
return x
|
216 |
+
|
217 |
+
# sin-cos position encoding
|
218 |
+
# https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
|
219 |
+
def get_sinusoid_encoding_table(n_position, d_hid, ckpt_num_frame=-1, cur_frame=12):
|
220 |
+
''' Sinusoid position encoding table '''
|
221 |
+
# TODO: make it with torch instead of numpy
|
222 |
+
def get_position_angle_vec(position):
|
223 |
+
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
|
224 |
+
|
225 |
+
if ckpt_num_frame != -1 and ckpt_num_frame != cur_frame:
|
226 |
+
logger.info(f"Interpolate position embedding")
|
227 |
+
logger.info(f"Testing frame: {cur_frame}")
|
228 |
+
logger.info(f"Checkpoint frame: {ckpt_num_frame}")
|
229 |
+
|
230 |
+
T = ckpt_num_frame # checkpoint frame
|
231 |
+
new_T = cur_frame # testing frame
|
232 |
+
n_position = n_position // new_T * T # generate checkpoint position embedding
|
233 |
+
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
|
234 |
+
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
235 |
+
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
236 |
+
sinusoid_table = torch.tensor(sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0)
|
237 |
+
# interpolate
|
238 |
+
P = int((n_position // T) ** 0.5)
|
239 |
+
C = d_hid
|
240 |
+
sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C)
|
241 |
+
sinusoid_table = sinusoid_table.permute(0, 2, 3, 4, 1).reshape(-1, C, T) # BHW, C, T
|
242 |
+
sinusoid_table = torch.nn.functional.interpolate(sinusoid_table, size=new_T, mode='linear')
|
243 |
+
sinusoid_table = sinusoid_table.reshape(1, P, P, C, new_T).permute(0, 4, 1, 2, 3) # B, T, H, W, C
|
244 |
+
sinusoid_table = sinusoid_table.flatten(1, 3)
|
245 |
+
return sinusoid_table
|
246 |
+
else:
|
247 |
+
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
|
248 |
+
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
249 |
+
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
250 |
+
return torch.tensor(sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0)
|
251 |
+
|
252 |
+
|
253 |
+
def get_sinusoid_encoding_table2(n_position=784, d_hid=1024, cur_frame=8, ckpt_num_frame=4, pre_n_position=784):
|
254 |
+
''' Sinusoid position encoding table '''
|
255 |
+
# TODO: make it with torch instead of numpy
|
256 |
+
def get_position_angle_vec(position):
|
257 |
+
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
|
258 |
+
|
259 |
+
# generate checkpoint position embedding
|
260 |
+
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(pre_n_position)])
|
261 |
+
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
262 |
+
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
263 |
+
sinusoid_table = torch.tensor(sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0)
|
264 |
+
|
265 |
+
print(f"n_position: {n_position}")
|
266 |
+
print(f"pre_n_position: {pre_n_position}")
|
267 |
+
|
268 |
+
if n_position != pre_n_position:
|
269 |
+
T = ckpt_num_frame # checkpoint frame
|
270 |
+
P = 14 # checkpoint size
|
271 |
+
C = d_hid
|
272 |
+
new_P = int((n_position // cur_frame) ** 0.5) # testing size
|
273 |
+
print(f'Pretraining uses 14x14, but current version is {new_P}x{new_P}')
|
274 |
+
print(f'Interpolate the position embedding')
|
275 |
+
sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C)
|
276 |
+
sinusoid_table = sinusoid_table.reshape(-1, P, P, C).permute(0, 3, 1, 2)
|
277 |
+
sinusoid_table = torch.nn.functional.interpolate(
|
278 |
+
sinusoid_table, size=(new_P, new_P), mode='bicubic', align_corners=False)
|
279 |
+
# BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
|
280 |
+
sinusoid_table = sinusoid_table.permute(0, 2, 3, 1).reshape(-1, T, new_P, new_P, C)
|
281 |
+
sinusoid_table = sinusoid_table.flatten(1, 3) # B, THW, C
|
282 |
+
|
283 |
+
if cur_frame != ckpt_num_frame:
|
284 |
+
print(f'Pretraining uses 4 frames, but current frame is {cur_frame}')
|
285 |
+
print(f'Interpolate the position embedding')
|
286 |
+
T = ckpt_num_frame # checkpoint frame
|
287 |
+
new_T = cur_frame # testing frame
|
288 |
+
# interpolate
|
289 |
+
P = int((n_position // cur_frame) ** 0.5) # testing size
|
290 |
+
C = d_hid
|
291 |
+
sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C)
|
292 |
+
sinusoid_table = sinusoid_table.permute(0, 2, 3, 4, 1).reshape(-1, C, T) # BHW, C, T
|
293 |
+
sinusoid_table = torch.nn.functional.interpolate(sinusoid_table, size=new_T, mode='linear')
|
294 |
+
sinusoid_table = sinusoid_table.reshape(1, P, P, C, new_T).permute(0, 4, 1, 2, 3) # B, T, H, W, C
|
295 |
+
sinusoid_table = sinusoid_table.flatten(1, 3) # B, THW, C
|
296 |
+
|
297 |
+
return sinusoid_table
|
298 |
+
|
299 |
+
|
300 |
+
class PretrainVisionTransformerEncoder(nn.Module):
|
301 |
+
""" Vision Transformer with support for patch or hybrid CNN input stage
|
302 |
+
"""
|
303 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12,
|
304 |
+
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
305 |
+
drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, num_frames=8, tubelet_size=1,
|
306 |
+
use_learnable_pos_emb=False,
|
307 |
+
use_checkpoint=False, checkpoint_num=0,
|
308 |
+
ckpt_num_frame=-1, with_ln=True, return_index=-1
|
309 |
+
):
|
310 |
+
super().__init__()
|
311 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
312 |
+
self.patch_embed = PatchEmbed(
|
313 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
|
314 |
+
num_frames=num_frames, tubelet_size=tubelet_size
|
315 |
+
)
|
316 |
+
num_patches = self.patch_embed.num_patches
|
317 |
+
self.depth = depth + return_index + 1
|
318 |
+
self.use_checkpoint = use_checkpoint
|
319 |
+
self.checkpoint_num = checkpoint_num
|
320 |
+
logger.info(f"Use checkpoint: {use_checkpoint}")
|
321 |
+
logger.info(f"Checkpoint number: {checkpoint_num}")
|
322 |
+
logger.info(f"Real runing depth: {self.depth}")
|
323 |
+
|
324 |
+
# TODO: Add the cls token
|
325 |
+
if use_learnable_pos_emb:
|
326 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
327 |
+
self.img_pos_embed = nn.Parameter(torch.zeros(1, num_patches//(num_frames//tubelet_size) + 1, embed_dim))
|
328 |
+
else:
|
329 |
+
# sine-cosine positional embeddings
|
330 |
+
if img_size != 224:
|
331 |
+
self.pos_embed = get_sinusoid_encoding_table2(num_patches, embed_dim, ckpt_num_frame=ckpt_num_frame, cur_frame=num_frames//tubelet_size)
|
332 |
+
self.img_pos_embed = get_sinusoid_encoding_table2(num_patches//(num_frames//tubelet_size), embed_dim, cur_frame=1, ckpt_num_frame=1, pre_n_position=14*14)
|
333 |
+
else:
|
334 |
+
self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim, ckpt_num_frame=ckpt_num_frame, cur_frame=num_frames//tubelet_size)
|
335 |
+
self.img_pos_embed = get_sinusoid_encoding_table(num_patches//(num_frames//tubelet_size), embed_dim)
|
336 |
+
|
337 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
338 |
+
self.blocks = nn.ModuleList([
|
339 |
+
Block(
|
340 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
341 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
342 |
+
init_values=init_values)
|
343 |
+
for i in range(self.depth)])
|
344 |
+
|
345 |
+
if with_ln:
|
346 |
+
self.norm = norm_layer(embed_dim)
|
347 |
+
else:
|
348 |
+
self.norm = nn.Identity()
|
349 |
+
|
350 |
+
if use_learnable_pos_emb:
|
351 |
+
trunc_normal_(self.pos_embed, std=.02)
|
352 |
+
|
353 |
+
@torch.jit.ignore
|
354 |
+
def no_weight_decay(self):
|
355 |
+
return {'pos_embed', 'cls_token'}
|
356 |
+
|
357 |
+
def forward_features(self, x, use_image=False):
|
358 |
+
x = self.patch_embed(x)
|
359 |
+
|
360 |
+
if use_image:
|
361 |
+
x = x + self.img_pos_embed.type_as(x).to(x.device).clone().detach()
|
362 |
+
else:
|
363 |
+
x = x + self.pos_embed.type_as(x).to(x.device).clone().detach()
|
364 |
+
|
365 |
+
B, _, C = x.shape
|
366 |
+
x_vis = x
|
367 |
+
|
368 |
+
for idx, blk in enumerate(self.blocks):
|
369 |
+
if self.use_checkpoint and idx < self.checkpoint_num:
|
370 |
+
x_vis = checkpoint.checkpoint(blk, x_vis)
|
371 |
+
else:
|
372 |
+
x_vis = blk(x_vis)
|
373 |
+
|
374 |
+
# with ln ot not
|
375 |
+
x_vis = self.norm(x_vis)
|
376 |
+
return x_vis
|
377 |
+
|
378 |
+
def forward(self, x, use_image=False):
|
379 |
+
x_vis = self.forward_features(x, use_image)
|
380 |
+
return x_vis
|
381 |
+
|
382 |
+
|
383 |
+
class PretrainVisionTransformer(nn.Module):
|
384 |
+
""" Vision Transformer with support for patch or hybrid CNN input stage
|
385 |
+
"""
|
386 |
+
def __init__(self,
|
387 |
+
img_size=224,
|
388 |
+
patch_size=16,
|
389 |
+
encoder_in_chans=3,
|
390 |
+
encoder_embed_dim=768,
|
391 |
+
encoder_depth=12,
|
392 |
+
encoder_num_heads=12,
|
393 |
+
mlp_ratio=4.,
|
394 |
+
qkv_bias=True,
|
395 |
+
qk_scale=None,
|
396 |
+
drop_rate=0.,
|
397 |
+
attn_drop_rate=0.,
|
398 |
+
drop_path_rate=0.,
|
399 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
400 |
+
init_values=0.,
|
401 |
+
use_learnable_pos_emb=False,
|
402 |
+
num_frames=8,
|
403 |
+
tubelet_size=1,
|
404 |
+
use_checkpoint=False,
|
405 |
+
checkpoint_num=0,
|
406 |
+
ckpt_num_frame=4, # the pretrained model uses 4 frames
|
407 |
+
return_index=-1,
|
408 |
+
with_ln=False
|
409 |
+
):
|
410 |
+
super().__init__()
|
411 |
+
|
412 |
+
self.encoder = PretrainVisionTransformerEncoder(
|
413 |
+
img_size=img_size,
|
414 |
+
patch_size=patch_size,
|
415 |
+
in_chans=encoder_in_chans,
|
416 |
+
embed_dim=encoder_embed_dim,
|
417 |
+
depth=encoder_depth,
|
418 |
+
num_heads=encoder_num_heads,
|
419 |
+
mlp_ratio=mlp_ratio,
|
420 |
+
qkv_bias=qkv_bias,
|
421 |
+
qk_scale=qk_scale,
|
422 |
+
drop_rate=drop_rate,
|
423 |
+
attn_drop_rate=attn_drop_rate,
|
424 |
+
drop_path_rate=drop_path_rate,
|
425 |
+
norm_layer=norm_layer,
|
426 |
+
init_values=init_values,
|
427 |
+
num_frames=num_frames,
|
428 |
+
tubelet_size=tubelet_size,
|
429 |
+
use_learnable_pos_emb=use_learnable_pos_emb,
|
430 |
+
use_checkpoint=use_checkpoint,
|
431 |
+
checkpoint_num=checkpoint_num,
|
432 |
+
ckpt_num_frame=ckpt_num_frame,
|
433 |
+
with_ln=with_ln,
|
434 |
+
return_index=return_index
|
435 |
+
)
|
436 |
+
logger.info(f'With LN: {with_ln}')
|
437 |
+
logger.info(f'Total {encoder_depth} layer')
|
438 |
+
logger.info(f'Return {encoder_depth+return_index+1}-th layer')
|
439 |
+
|
440 |
+
self.apply(self._init_weights)
|
441 |
+
|
442 |
+
def _init_weights(self, m):
|
443 |
+
if isinstance(m, nn.Linear):
|
444 |
+
nn.init.xavier_uniform_(m.weight)
|
445 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
446 |
+
nn.init.constant_(m.bias, 0)
|
447 |
+
elif isinstance(m, nn.LayerNorm):
|
448 |
+
nn.init.constant_(m.bias, 0)
|
449 |
+
nn.init.constant_(m.weight, 1.0)
|
450 |
+
|
451 |
+
@torch.jit.ignore
|
452 |
+
def no_weight_decay(self):
|
453 |
+
return {'pos_embed', 'cls_token', 'clip_pos_embed'}
|
454 |
+
|
455 |
+
def forward(self, x, use_image=False):
|
456 |
+
T = x.shape[2]
|
457 |
+
x_vis = self.encoder(x, use_image) # [B, N_vis, C_e]
|
458 |
+
B, TL, C = x_vis.shape
|
459 |
+
x_vis = x_vis.view(B, T, TL // T, C)
|
460 |
+
|
461 |
+
return x_vis
|
462 |
+
|
463 |
+
|
464 |
+
def build_vit(config):
|
465 |
+
model = PretrainVisionTransformer(
|
466 |
+
img_size=config.vision_encoder.img_size,
|
467 |
+
patch_size=config.vision_encoder.patch_size,
|
468 |
+
encoder_embed_dim=config.vision_encoder.encoder_embed_dim,
|
469 |
+
encoder_depth=config.vision_encoder.encoder_depth,
|
470 |
+
encoder_num_heads=config.vision_encoder.encoder_num_heads,
|
471 |
+
drop_path_rate=config.vision_encoder.drop_path_rate,
|
472 |
+
num_frames=config.vision_encoder.num_frames,
|
473 |
+
tubelet_size=config.vision_encoder.tubelet_size,
|
474 |
+
use_checkpoint=config.vision_encoder.use_checkpoint,
|
475 |
+
checkpoint_num=config.vision_encoder.checkpoint_num,
|
476 |
+
return_index=config.vision_encoder.get('return_index', -1),
|
477 |
+
with_ln=config.vision_encoder.get('with_ln', False),
|
478 |
+
)
|
479 |
+
model.default_cfg = _cfg()
|
480 |
+
if config.vision_encoder.pretrained:
|
481 |
+
logger.info(f"Loading pretrained weights from {config.vision_encoder.pretrained}")
|
482 |
+
state_dict = torch.load(config.vision_encoder.pretrained, map_location='cpu')
|
483 |
+
model.load_state_dict(state_dict, strict=False)
|
484 |
+
else:
|
485 |
+
logger.info("No pretrained weights!!!")
|
486 |
+
return model
|
487 |
+
|
special_tokens_map.json
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": {
|
3 |
+
"content": "<s>",
|
4 |
+
"lstrip": false,
|
5 |
+
"normalized": false,
|
6 |
+
"rstrip": false,
|
7 |
+
"single_word": false
|
8 |
+
},
|
9 |
+
"eos_token": {
|
10 |
+
"content": "</s>",
|
11 |
+
"lstrip": false,
|
12 |
+
"normalized": false,
|
13 |
+
"rstrip": false,
|
14 |
+
"single_word": false
|
15 |
+
},
|
16 |
+
"pad_token": "<unk>",
|
17 |
+
"unk_token": {
|
18 |
+
"content": "<unk>",
|
19 |
+
"lstrip": false,
|
20 |
+
"normalized": false,
|
21 |
+
"rstrip": false,
|
22 |
+
"single_word": false
|
23 |
+
}
|
24 |
+
}
|
third_party/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
logger = logging.getLogger(__name__)
|
third_party/cgdetr/cg_detr/__init__.py
ADDED
File without changes
|
third_party/cgdetr/cg_detr/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (168 Bytes). View file
|
|
third_party/cgdetr/cg_detr/__pycache__/attention.cpython-310.pyc
ADDED
Binary file (15.2 kB). View file
|
|
third_party/cgdetr/cg_detr/__pycache__/crossattention.cpython-310.pyc
ADDED
Binary file (15.4 kB). View file
|
|
third_party/cgdetr/cg_detr/__pycache__/matcher.cpython-310.pyc
ADDED
Binary file (4.65 kB). View file
|
|
third_party/cgdetr/cg_detr/__pycache__/misc.cpython-310.pyc
ADDED
Binary file (714 Bytes). View file
|
|
third_party/cgdetr/cg_detr/__pycache__/model.cpython-310.pyc
ADDED
Binary file (32.6 kB). View file
|
|
third_party/cgdetr/cg_detr/__pycache__/position_encoding.cpython-310.pyc
ADDED
Binary file (4.33 kB). View file
|
|
third_party/cgdetr/cg_detr/__pycache__/span_utils.cpython-310.pyc
ADDED
Binary file (4.19 kB). View file
|
|
third_party/cgdetr/cg_detr/__pycache__/transformer.cpython-310.pyc
ADDED
Binary file (22.7 kB). View file
|
|
third_party/cgdetr/cg_detr/attention.py
ADDED
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# DAB-DETR
|
3 |
+
# Copyright (c) 2022 IDEA. All Rights Reserved.
|
4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
5 |
+
# ------------------------------------------------------------------------
|
6 |
+
# Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
|
7 |
+
# Copyright (c) 2021 Microsoft. All Rights Reserved.
|
8 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
9 |
+
# ------------------------------------------------------------------------
|
10 |
+
# Modified from DETR (https://github.com/facebookresearch/detr)
|
11 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
12 |
+
# ------------------------------------------------------------------------
|
13 |
+
# Modified from codes in torch.nn
|
14 |
+
# ------------------------------------------------------------------------
|
15 |
+
|
16 |
+
"""
|
17 |
+
MultiheadAttention that support query, key, and value to have different dimensions.
|
18 |
+
Query, key, and value projections are removed.
|
19 |
+
Mostly copy-paste from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/activation.py#L873
|
20 |
+
and https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L4837
|
21 |
+
"""
|
22 |
+
|
23 |
+
import copy
|
24 |
+
from typing import Optional, List
|
25 |
+
|
26 |
+
import torch
|
27 |
+
import torch.nn.functional as F
|
28 |
+
from torch import nn, Tensor
|
29 |
+
|
30 |
+
import warnings
|
31 |
+
from typing import Tuple, Optional
|
32 |
+
|
33 |
+
import torch
|
34 |
+
from torch import Tensor
|
35 |
+
from torch.nn.modules.linear import Linear
|
36 |
+
from torch.nn.init import xavier_uniform_
|
37 |
+
from torch.nn.init import constant_
|
38 |
+
from torch.nn.init import xavier_normal_
|
39 |
+
from torch.nn.parameter import Parameter
|
40 |
+
from torch.nn.modules.module import Module
|
41 |
+
from torch.nn import functional as F
|
42 |
+
|
43 |
+
import warnings
|
44 |
+
import math
|
45 |
+
|
46 |
+
from torch._C import _infer_size, _add_docstr
|
47 |
+
from torch.nn import _reduction as _Reduction
|
48 |
+
from torch.nn.modules import utils
|
49 |
+
from torch.nn.modules.utils import _single, _pair, _triple, _list_with_default
|
50 |
+
from torch.nn import grad
|
51 |
+
from torch import _VF
|
52 |
+
from torch._jit_internal import boolean_dispatch, List, Optional, _overload, Tuple
|
53 |
+
try:
|
54 |
+
from torch.overrides import has_torch_function, handle_torch_function
|
55 |
+
except:
|
56 |
+
from torch._overrides import has_torch_function, handle_torch_function
|
57 |
+
Tensor = torch.Tensor
|
58 |
+
|
59 |
+
from torch.nn.functional import linear, pad, softmax, dropout
|
60 |
+
|
61 |
+
class MultiheadAttention(Module):
|
62 |
+
r"""Allows the model to jointly attend to information
|
63 |
+
from different representation subspaces.
|
64 |
+
See reference: Attention Is All You Need
|
65 |
+
.. math::
|
66 |
+
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
|
67 |
+
\text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
|
68 |
+
Args:
|
69 |
+
embed_dim: total dimension of the model.
|
70 |
+
num_heads: parallel attention heads.
|
71 |
+
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
|
72 |
+
bias: add bias as module parameter. Default: True.
|
73 |
+
add_bias_kv: add bias to the key and value sequences at dim=0.
|
74 |
+
add_zero_attn: add a new batch of zeros to the key and
|
75 |
+
value sequences at dim=1.
|
76 |
+
kdim: total number of features in key. Default: None.
|
77 |
+
vdim: total number of features in value. Default: None.
|
78 |
+
Note: if kdim and vdim are None, they will be set to embed_dim such that
|
79 |
+
query, key, and value have the same number of features.
|
80 |
+
Examples::
|
81 |
+
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
|
82 |
+
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
|
83 |
+
"""
|
84 |
+
bias_k: Optional[torch.Tensor]
|
85 |
+
bias_v: Optional[torch.Tensor]
|
86 |
+
|
87 |
+
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
|
88 |
+
super(MultiheadAttention, self).__init__()
|
89 |
+
self.embed_dim = embed_dim
|
90 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
91 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
92 |
+
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
93 |
+
|
94 |
+
self.num_heads = num_heads
|
95 |
+
self.dropout = dropout
|
96 |
+
self.head_dim = embed_dim // num_heads
|
97 |
+
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
98 |
+
|
99 |
+
vdim = vdim if vdim is not None else embed_dim
|
100 |
+
self.out_proj = Linear(vdim , vdim)
|
101 |
+
|
102 |
+
self.in_proj_bias = None
|
103 |
+
self.in_proj_weight = None
|
104 |
+
self.bias_k = self.bias_v = None
|
105 |
+
self.q_proj_weight = None
|
106 |
+
self.k_proj_weight = None
|
107 |
+
self.v_proj_weight = None
|
108 |
+
|
109 |
+
self.add_zero_attn = add_zero_attn
|
110 |
+
|
111 |
+
self._reset_parameters()
|
112 |
+
|
113 |
+
def _reset_parameters(self):
|
114 |
+
constant_(self.out_proj.bias, 0.)
|
115 |
+
|
116 |
+
def __setstate__(self, state):
|
117 |
+
# Support loading old MultiheadAttention checkpoints generated by v1.1.0
|
118 |
+
if '_qkv_same_embed_dim' not in state:
|
119 |
+
state['_qkv_same_embed_dim'] = True
|
120 |
+
|
121 |
+
super(MultiheadAttention, self).__setstate__(state)
|
122 |
+
|
123 |
+
def forward(self, query, key, value, key_padding_mask=None,
|
124 |
+
need_weights=True, attn_mask=None):
|
125 |
+
# type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
|
126 |
+
r"""
|
127 |
+
Args:
|
128 |
+
query, key, value: map a query and a set of key-value pairs to an output.
|
129 |
+
See "Attention Is All You Need" for more details.
|
130 |
+
key_padding_mask: if provided, specified padding elements in the key will
|
131 |
+
be ignored by the attention. When given a binary mask and a value is True,
|
132 |
+
the corresponding value on the attention layer will be ignored. When given
|
133 |
+
a byte mask and a value is non-zero, the corresponding value on the attention
|
134 |
+
layer will be ignored
|
135 |
+
need_weights: output attn_output_weights.
|
136 |
+
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
137 |
+
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
138 |
+
Shape:
|
139 |
+
- Inputs:
|
140 |
+
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
141 |
+
the embedding dimension.
|
142 |
+
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
143 |
+
the embedding dimension.
|
144 |
+
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
145 |
+
the embedding dimension.
|
146 |
+
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
147 |
+
If a ByteTensor is provided, the non-zero positions will be ignored while the position
|
148 |
+
with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
|
149 |
+
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
150 |
+
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
151 |
+
3D mask :math:`(N*\text{num_heads}, L, S)` where N is the batch size, L is the target sequence length,
|
152 |
+
S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
|
153 |
+
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
|
154 |
+
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
|
155 |
+
is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
156 |
+
is provided, it will be added to the attention weight.
|
157 |
+
- Outputs:
|
158 |
+
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
159 |
+
E is the embedding dimension.
|
160 |
+
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
|
161 |
+
L is the target sequence length, S is the source sequence length.
|
162 |
+
"""
|
163 |
+
if not self._qkv_same_embed_dim:
|
164 |
+
return multi_head_attention_forward(
|
165 |
+
query, key, value, self.embed_dim, self.num_heads,
|
166 |
+
self.in_proj_weight, self.in_proj_bias,
|
167 |
+
self.bias_k, self.bias_v, self.add_zero_attn,
|
168 |
+
self.dropout, self.out_proj.weight, self.out_proj.bias,
|
169 |
+
training=self.training,
|
170 |
+
key_padding_mask=key_padding_mask, need_weights=need_weights,
|
171 |
+
attn_mask=attn_mask, use_separate_proj_weight=True,
|
172 |
+
q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
|
173 |
+
v_proj_weight=self.v_proj_weight, out_dim=self.vdim)
|
174 |
+
else:
|
175 |
+
return multi_head_attention_forward(
|
176 |
+
query, key, value, self.embed_dim, self.num_heads,
|
177 |
+
self.in_proj_weight, self.in_proj_bias,
|
178 |
+
self.bias_k, self.bias_v, self.add_zero_attn,
|
179 |
+
self.dropout, self.out_proj.weight, self.out_proj.bias,
|
180 |
+
training=self.training,
|
181 |
+
key_padding_mask=key_padding_mask, need_weights=need_weights,
|
182 |
+
attn_mask=attn_mask, out_dim=self.vdim)
|
183 |
+
|
184 |
+
|
185 |
+
def multi_head_attention_forward(query: Tensor,
|
186 |
+
key: Tensor,
|
187 |
+
value: Tensor,
|
188 |
+
embed_dim_to_check: int,
|
189 |
+
num_heads: int,
|
190 |
+
in_proj_weight: Tensor,
|
191 |
+
in_proj_bias: Tensor,
|
192 |
+
bias_k: Optional[Tensor],
|
193 |
+
bias_v: Optional[Tensor],
|
194 |
+
add_zero_attn: bool,
|
195 |
+
dropout_p: float,
|
196 |
+
out_proj_weight: Tensor,
|
197 |
+
out_proj_bias: Tensor,
|
198 |
+
training: bool = True,
|
199 |
+
key_padding_mask: Optional[Tensor] = None,
|
200 |
+
need_weights: bool = True,
|
201 |
+
attn_mask: Optional[Tensor] = None,
|
202 |
+
use_separate_proj_weight: bool = False,
|
203 |
+
q_proj_weight: Optional[Tensor] = None,
|
204 |
+
k_proj_weight: Optional[Tensor] = None,
|
205 |
+
v_proj_weight: Optional[Tensor] = None,
|
206 |
+
static_k: Optional[Tensor] = None,
|
207 |
+
static_v: Optional[Tensor] = None,
|
208 |
+
out_dim: Optional[Tensor] = None
|
209 |
+
) -> Tuple[Tensor, Optional[Tensor]]:
|
210 |
+
r"""
|
211 |
+
Args:
|
212 |
+
query, key, value: map a query and a set of key-value pairs to an output.
|
213 |
+
See "Attention Is All You Need" for more details.
|
214 |
+
embed_dim_to_check: total dimension of the model.
|
215 |
+
num_heads: parallel attention heads.
|
216 |
+
in_proj_weight, in_proj_bias: input projection weight and bias.
|
217 |
+
bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
|
218 |
+
add_zero_attn: add a new batch of zeros to the key and
|
219 |
+
value sequences at dim=1.
|
220 |
+
dropout_p: probability of an element to be zeroed.
|
221 |
+
out_proj_weight, out_proj_bias: the output projection weight and bias.
|
222 |
+
training: apply dropout if is ``True``.
|
223 |
+
key_padding_mask: if provided, specified padding elements in the key will
|
224 |
+
be ignored by the attention. This is an binary mask. When the value is True,
|
225 |
+
the corresponding value on the attention layer will be filled with -inf.
|
226 |
+
need_weights: output attn_output_weights.
|
227 |
+
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
228 |
+
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
229 |
+
use_separate_proj_weight: the function accept the proj. weights for query, key,
|
230 |
+
and value in different forms. If false, in_proj_weight will be used, which is
|
231 |
+
a combination of q_proj_weight, k_proj_weight, v_proj_weight.
|
232 |
+
q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
|
233 |
+
static_k, static_v: static key and value used for attention operators.
|
234 |
+
Shape:
|
235 |
+
Inputs:
|
236 |
+
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
237 |
+
the embedding dimension.
|
238 |
+
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
239 |
+
the embedding dimension.
|
240 |
+
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
241 |
+
the embedding dimension.
|
242 |
+
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
243 |
+
If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
|
244 |
+
will be unchanged. If a BoolTensor is provided, the positions with the
|
245 |
+
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
246 |
+
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
247 |
+
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
|
248 |
+
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
|
249 |
+
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
|
250 |
+
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
|
251 |
+
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
252 |
+
is provided, it will be added to the attention weight.
|
253 |
+
- static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
|
254 |
+
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
|
255 |
+
- static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
|
256 |
+
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
|
257 |
+
Outputs:
|
258 |
+
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
259 |
+
E is the embedding dimension.
|
260 |
+
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
|
261 |
+
L is the target sequence length, S is the source sequence length.
|
262 |
+
"""
|
263 |
+
if not torch.jit.is_scripting():
|
264 |
+
tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v,
|
265 |
+
out_proj_weight, out_proj_bias)
|
266 |
+
if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
|
267 |
+
return handle_torch_function(
|
268 |
+
multi_head_attention_forward, tens_ops, query, key, value,
|
269 |
+
embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias,
|
270 |
+
bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight,
|
271 |
+
out_proj_bias, training=training, key_padding_mask=key_padding_mask,
|
272 |
+
need_weights=need_weights, attn_mask=attn_mask,
|
273 |
+
use_separate_proj_weight=use_separate_proj_weight,
|
274 |
+
q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight,
|
275 |
+
v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v)
|
276 |
+
tgt_len, bsz, embed_dim = query.size()
|
277 |
+
assert embed_dim == embed_dim_to_check
|
278 |
+
# allow MHA to have different sizes for the feature dimension
|
279 |
+
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
|
280 |
+
|
281 |
+
head_dim = embed_dim // num_heads
|
282 |
+
v_head_dim = out_dim // num_heads
|
283 |
+
assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
|
284 |
+
scaling = float(head_dim) ** -0.5
|
285 |
+
|
286 |
+
q = query * scaling
|
287 |
+
k = key
|
288 |
+
v = value
|
289 |
+
|
290 |
+
if attn_mask is not None:
|
291 |
+
assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
|
292 |
+
attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
|
293 |
+
'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
|
294 |
+
if attn_mask.dtype == torch.uint8:
|
295 |
+
warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
|
296 |
+
attn_mask = attn_mask.to(torch.bool)
|
297 |
+
|
298 |
+
if attn_mask.dim() == 2:
|
299 |
+
attn_mask = attn_mask.unsqueeze(0)
|
300 |
+
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
|
301 |
+
raise RuntimeError('The size of the 2D attn_mask is not correct.')
|
302 |
+
elif attn_mask.dim() == 3:
|
303 |
+
if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
|
304 |
+
raise RuntimeError('The size of the 3D attn_mask is not correct.')
|
305 |
+
else:
|
306 |
+
raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
|
307 |
+
# attn_mask's dim is 3 now.
|
308 |
+
|
309 |
+
# convert ByteTensor key_padding_mask to bool
|
310 |
+
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
|
311 |
+
warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
|
312 |
+
key_padding_mask = key_padding_mask.to(torch.bool)
|
313 |
+
|
314 |
+
if bias_k is not None and bias_v is not None:
|
315 |
+
if static_k is None and static_v is None:
|
316 |
+
k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
|
317 |
+
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
|
318 |
+
if attn_mask is not None:
|
319 |
+
attn_mask = pad(attn_mask, (0, 1))
|
320 |
+
if key_padding_mask is not None:
|
321 |
+
key_padding_mask = pad(key_padding_mask, (0, 1))
|
322 |
+
else:
|
323 |
+
assert static_k is None, "bias cannot be added to static key."
|
324 |
+
assert static_v is None, "bias cannot be added to static value."
|
325 |
+
else:
|
326 |
+
assert bias_k is None
|
327 |
+
assert bias_v is None
|
328 |
+
|
329 |
+
q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
|
330 |
+
if k is not None:
|
331 |
+
k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
|
332 |
+
if v is not None:
|
333 |
+
v = v.contiguous().view(-1, bsz * num_heads, v_head_dim).transpose(0, 1)
|
334 |
+
|
335 |
+
if static_k is not None:
|
336 |
+
assert static_k.size(0) == bsz * num_heads
|
337 |
+
assert static_k.size(2) == head_dim
|
338 |
+
k = static_k
|
339 |
+
|
340 |
+
if static_v is not None:
|
341 |
+
assert static_v.size(0) == bsz * num_heads
|
342 |
+
assert static_v.size(2) == v_head_dim
|
343 |
+
v = static_v
|
344 |
+
|
345 |
+
src_len = k.size(1)
|
346 |
+
|
347 |
+
if key_padding_mask is not None:
|
348 |
+
assert key_padding_mask.size(0) == bsz
|
349 |
+
assert key_padding_mask.size(1) == src_len
|
350 |
+
|
351 |
+
if add_zero_attn:
|
352 |
+
src_len += 1
|
353 |
+
k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
|
354 |
+
v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
|
355 |
+
if attn_mask is not None:
|
356 |
+
attn_mask = pad(attn_mask, (0, 1))
|
357 |
+
if key_padding_mask is not None:
|
358 |
+
key_padding_mask = pad(key_padding_mask, (0, 1))
|
359 |
+
|
360 |
+
attn_output_weights = torch.bmm(q, k.transpose(1, 2))
|
361 |
+
assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
|
362 |
+
|
363 |
+
if attn_mask is not None:
|
364 |
+
if attn_mask.dtype == torch.bool:
|
365 |
+
attn_output_weights.masked_fill_(attn_mask, float('-inf'))
|
366 |
+
else:
|
367 |
+
attn_output_weights += attn_mask
|
368 |
+
|
369 |
+
|
370 |
+
if key_padding_mask is not None:
|
371 |
+
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
|
372 |
+
attn_output_weights = attn_output_weights.masked_fill(
|
373 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2),
|
374 |
+
float('-inf'),
|
375 |
+
)
|
376 |
+
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
|
377 |
+
|
378 |
+
# attn_output_weights = softmax(
|
379 |
+
# attn_output_weights, dim=-1)
|
380 |
+
attn_output_weights = softmax(
|
381 |
+
attn_output_weights - attn_output_weights.max(dim=-1, keepdim=True)[0], dim=-1)
|
382 |
+
attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)
|
383 |
+
|
384 |
+
attn_output = torch.bmm(attn_output_weights, v)
|
385 |
+
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, v_head_dim]
|
386 |
+
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, out_dim)
|
387 |
+
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
388 |
+
|
389 |
+
if need_weights:
|
390 |
+
# average attention weights over heads
|
391 |
+
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
|
392 |
+
return attn_output, attn_output_weights.sum(dim=1) / num_heads
|
393 |
+
else:
|
394 |
+
return attn_output, None
|
third_party/cgdetr/cg_detr/config.py
ADDED
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import torch
|
4 |
+
import argparse
|
5 |
+
|
6 |
+
from third_party.cgdetr.utils.basic_utils import mkdirp, load_json, save_json, make_zipfile, dict_to_markdown
|
7 |
+
import shutil
|
8 |
+
|
9 |
+
class BaseOptions(object):
|
10 |
+
saved_option_filename = "opt.json"
|
11 |
+
ckpt_filename = "model.ckpt"
|
12 |
+
tensorboard_log_dir = "tensorboard_log"
|
13 |
+
train_log_filename = "train.log.txt"
|
14 |
+
eval_log_filename = "eval.log.txt"
|
15 |
+
|
16 |
+
def __init__(self):
|
17 |
+
self.parser = None
|
18 |
+
self.initialized = False
|
19 |
+
self.opt = None
|
20 |
+
|
21 |
+
def initialize(self):
|
22 |
+
self.initialized = True
|
23 |
+
parser = argparse.ArgumentParser()
|
24 |
+
# parser.add_argument("--dset_name", type=str, choices=["hl", 'charadesSTA', ])
|
25 |
+
# parser.add_argument("--dset_domain", type=str,
|
26 |
+
# help="Domain to train for tvsum dataset. (Only used for tvsum and youtube-hl)")
|
27 |
+
|
28 |
+
parser.add_argument("--eval_split_name", type=str, default="val",
|
29 |
+
help="should match keys in video_duration_idx_path, must set for VCMR")
|
30 |
+
parser.add_argument("--debug", action="store_true",
|
31 |
+
help="debug (fast) mode, break all loops, do not load all data into memory.")
|
32 |
+
parser.add_argument("--data_ratio", type=float, default=1.0,
|
33 |
+
help="how many training and eval data to use. 1.0: use all, 0.1: use 10%."
|
34 |
+
"Use small portion for debug purposes. Note this is different from --debug, "
|
35 |
+
"which works by breaking the loops, typically they are not used together.")
|
36 |
+
parser.add_argument("--results_root", type=str, default="results")
|
37 |
+
parser.add_argument("--exp_id", type=str, default=None, help="id of this run, required at training")
|
38 |
+
parser.add_argument("--seed", type=int, default=2018, help="random seed")
|
39 |
+
# parser.add_argument("--device", type=int, default=0, help="0 cuda, -1 cpu")
|
40 |
+
parser.add_argument("--num_workers", type=int, default=0,
|
41 |
+
help="num subprocesses used to load the data, 0: use main process")
|
42 |
+
parser.add_argument("--no_pin_memory", action="store_true",
|
43 |
+
help="Don't use pin_memory=True for dataloader. "
|
44 |
+
"ref: https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/4")
|
45 |
+
|
46 |
+
# training config
|
47 |
+
# parser.add_argument("--lr", type=float, default=2e-4, help="learning rate")
|
48 |
+
# parser.add_argument("--lr_drop", type=int, default=800, help="drop learning rate to 1/10 every lr_drop epochs")
|
49 |
+
# parser.add_argument("--wd", type=float, default=1e-4, help="weight decay")
|
50 |
+
parser.add_argument("--n_epoch", type=int, default=200, help="number of epochs to run")
|
51 |
+
parser.add_argument("--max_es_cnt", type=int, default=200,
|
52 |
+
help="number of epochs to early stop, use -1 to disable early stop")
|
53 |
+
# parser.add_argument("--bsz", type=int, default=32, help="mini-batch size")
|
54 |
+
# parser.add_argument("--eval_bsz", type=int, default=100,
|
55 |
+
# help="mini-batch size at inference, for query")
|
56 |
+
parser.add_argument("--eval_epoch", type=int, default=5,help="inference epoch")
|
57 |
+
parser.add_argument("--grad_clip", type=float, default=0.1, help="perform gradient clip, -1: disable")
|
58 |
+
parser.add_argument("--eval_untrained", action="store_true", help="Evaluate on un-trained model")
|
59 |
+
parser.add_argument("--resume", type=str, default=None,
|
60 |
+
help="checkpoint path to resume or evaluate, without --resume_all this only load weights")
|
61 |
+
parser.add_argument("--resume_all", action="store_true",
|
62 |
+
help="if --resume_all, load optimizer/scheduler/epoch as well")
|
63 |
+
parser.add_argument("--start_epoch", type=int, default=None,
|
64 |
+
help="if None, will be set automatically when using --resume_all")
|
65 |
+
|
66 |
+
# Data config
|
67 |
+
parser.add_argument("--max_q_l", type=int, default=-1)
|
68 |
+
parser.add_argument("--max_v_l", type=int, default=-1)
|
69 |
+
parser.add_argument("--clip_length", type=float, default=2)
|
70 |
+
parser.add_argument("--max_windows", type=int, default=5)
|
71 |
+
|
72 |
+
parser.add_argument("--train_path", type=str, default=None)
|
73 |
+
parser.add_argument("--eval_path", type=str, default=None,
|
74 |
+
help="Evaluating during training, for Dev set. If None, will only do training, ")
|
75 |
+
parser.add_argument("--no_norm_vfeat", action="store_true", help="Do not do normalize video feat")
|
76 |
+
parser.add_argument("--no_norm_tfeat", action="store_true", help="Do not do normalize text feat")
|
77 |
+
parser.add_argument("--v_feat_dirs", type=str, nargs="+",
|
78 |
+
help="video feature dirs. If more than one, will concat their features. "
|
79 |
+
"Note that sub ctx features are also accepted here.")
|
80 |
+
parser.add_argument("--t_feat_dir", type=str, help="text/query feature dir")
|
81 |
+
# parser.add_argument("--a_feat_dir", type=str, help="audio feature dir")
|
82 |
+
parser.add_argument("--v_feat_dim", type=int, default=770, help="video feature dim")
|
83 |
+
parser.add_argument("--t_feat_dim", type=int, default=4096, help="text/query feature dim")
|
84 |
+
# parser.add_argument("--a_feat_dim", type=int, help="audio feature dim")
|
85 |
+
parser.add_argument("--ctx_mode", type=str, default="video_tef")
|
86 |
+
|
87 |
+
# Model config
|
88 |
+
parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
|
89 |
+
help="Type of positional embedding to use on top of the image features")
|
90 |
+
# * Transformer
|
91 |
+
parser.add_argument('--enc_layers', default=3, type=int,
|
92 |
+
help="Number of encoding layers in the transformer")
|
93 |
+
parser.add_argument('--dec_layers', default=3, type=int,
|
94 |
+
help="Number of decoding layers in the transformer")
|
95 |
+
parser.add_argument('--t2v_layers', default=2, type=int,
|
96 |
+
help="Number of decoding layers in the transformer")
|
97 |
+
parser.add_argument('--sent_layers', default=1, type=int,
|
98 |
+
help="Number of decoding layers in the transformer")
|
99 |
+
parser.add_argument('--moment_layers', default=1, type=int,
|
100 |
+
help="Number of decoding layers in the transformer")
|
101 |
+
parser.add_argument('--dummy_layers', default=2, type=int,
|
102 |
+
help="Number of encoding layers in the transformer")
|
103 |
+
parser.add_argument('--dim_feedforward', default=1024, type=int,
|
104 |
+
help="Intermediate size of the feedforward layers in the transformer blocks")
|
105 |
+
parser.add_argument('--hidden_dim', default=256, type=int,
|
106 |
+
help="Size of the embeddings (dimension of the transformer)")
|
107 |
+
parser.add_argument('--input_dropout', default=0.5, type=float,
|
108 |
+
help="Dropout applied in input")
|
109 |
+
parser.add_argument('--dropout', default=0.1, type=float,
|
110 |
+
help="Dropout applied in the transformer")
|
111 |
+
parser.add_argument("--txt_drop_ratio", default=0, type=float,
|
112 |
+
help="drop txt_drop_ratio tokens from text input. 0.1=10%")
|
113 |
+
parser.add_argument("--use_txt_pos", action="store_true", help="use position_embedding for text as well.")
|
114 |
+
parser.add_argument('--nheads', default=8, type=int,
|
115 |
+
help="Number of attention heads inside the transformer's attentions")
|
116 |
+
parser.add_argument('--num_queries', default=10, type=int,
|
117 |
+
help="Number of query slots")
|
118 |
+
parser.add_argument('--num_dummies', default=45, type=int,
|
119 |
+
help="Number of dummy tokens")
|
120 |
+
parser.add_argument('--total_prompts', default=10, type=int,
|
121 |
+
help="Number of query slots")
|
122 |
+
parser.add_argument('--num_prompts', default=1, type=int,
|
123 |
+
help="Number of dummy tokens")
|
124 |
+
parser.add_argument('--pre_norm', action='store_true')
|
125 |
+
# other model configs
|
126 |
+
parser.add_argument("--n_input_proj", type=int, default=2, help="#layers to encoder input")
|
127 |
+
parser.add_argument("--contrastive_hdim", type=int, default=64, help="dim for contrastive embeddings")
|
128 |
+
parser.add_argument("--temperature", type=float, default=0.07, help="temperature nce contrastive_align_loss")
|
129 |
+
# Loss
|
130 |
+
|
131 |
+
parser.add_argument("--saliency_margin", type=float, default=0.2)
|
132 |
+
parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
|
133 |
+
help="Disables auxiliary decoding losses (loss at each layer)")
|
134 |
+
parser.add_argument("--span_loss_type", default="l1", type=str, choices=['l1', 'ce'],
|
135 |
+
help="l1: (center-x, width) regression. ce: (st_idx, ed_idx) classification.")
|
136 |
+
parser.add_argument("--contrastive_align_loss", action="store_true",
|
137 |
+
help="Disable contrastive_align_loss between matched query spans and the text.")
|
138 |
+
# * Matcher
|
139 |
+
parser.add_argument('--set_cost_span', default=10, type=float,
|
140 |
+
help="L1 span coefficient in the matching cost")
|
141 |
+
parser.add_argument('--set_cost_giou', default=1, type=float,
|
142 |
+
help="giou span coefficient in the matching cost")
|
143 |
+
parser.add_argument('--set_cost_class', default=4, type=float,
|
144 |
+
help="Class coefficient in the matching cost")
|
145 |
+
|
146 |
+
# * Loss coefficients
|
147 |
+
parser.add_argument("--lw_saliency", type=float, default=1.,
|
148 |
+
help="weight for saliency loss, set to 0 will ignore")
|
149 |
+
parser.add_argument("--lw_wattn", type=float, default=1.,
|
150 |
+
help="weight for saliency loss, set to 0 will ignore")
|
151 |
+
parser.add_argument("--lw_ms_align", type=float, default=1.,
|
152 |
+
help="weight for saliency loss, set to 0 will ignore")
|
153 |
+
parser.add_argument("--lw_distill", type=float, default=1.,
|
154 |
+
help="weight for saliency loss, set to 0 will ignore")
|
155 |
+
parser.add_argument('--span_loss_coef', default=10, type=float)
|
156 |
+
parser.add_argument('--giou_loss_coef', default=1, type=float)
|
157 |
+
parser.add_argument('--label_loss_coef', default=4, type=float)
|
158 |
+
parser.add_argument('--eos_coef', default=0.1, type=float,
|
159 |
+
help="Relative classification weight of the no-object class")
|
160 |
+
parser.add_argument("--contrastive_align_loss_coef", default=0.0, type=float)
|
161 |
+
|
162 |
+
parser.add_argument("--no_sort_results", action="store_true",
|
163 |
+
help="do not sort results, use this for moment query visualization")
|
164 |
+
parser.add_argument("--max_before_nms", type=int, default=10)
|
165 |
+
parser.add_argument("--max_after_nms", type=int, default=10)
|
166 |
+
parser.add_argument("--conf_thd", type=float, default=0.0, help="only keep windows with conf >= conf_thd")
|
167 |
+
parser.add_argument("--nms_thd", type=float, default=-1,
|
168 |
+
help="additionally use non-maximum suppression "
|
169 |
+
"(or non-minimum suppression for distance)"
|
170 |
+
"to post-processing the predictions. "
|
171 |
+
"-1: do not use nms. [0, 1]")
|
172 |
+
self.parser = parser
|
173 |
+
|
174 |
+
def display_save(self, opt):
|
175 |
+
args = vars(opt)
|
176 |
+
# Display settings
|
177 |
+
print(dict_to_markdown(vars(opt), max_str_len=120))
|
178 |
+
# Save settings
|
179 |
+
if not isinstance(self, TestOptions):
|
180 |
+
option_file_path = os.path.join(opt.results_dir, self.saved_option_filename) # not yaml file indeed
|
181 |
+
save_json(args, option_file_path, save_pretty=True)
|
182 |
+
|
183 |
+
def parse(self, a_feat_dir=None):
|
184 |
+
if not self.initialized:
|
185 |
+
self.initialize()
|
186 |
+
opt = self.parser.parse_args()
|
187 |
+
|
188 |
+
if opt.debug:
|
189 |
+
opt.results_root = os.path.sep.join(opt.results_root.split(os.path.sep)[:-1] + ["debug_results", ])
|
190 |
+
opt.num_workers = 0
|
191 |
+
|
192 |
+
if isinstance(self, TestOptions):
|
193 |
+
# modify model_dir to absolute path
|
194 |
+
# opt.model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results", opt.model_dir)
|
195 |
+
opt.model_dir = os.path.dirname(opt.resume)
|
196 |
+
if a_feat_dir is not None:
|
197 |
+
opt.a_feat_dir = a_feat_dir
|
198 |
+
saved_options = load_json(os.path.join(opt.model_dir, self.saved_option_filename))
|
199 |
+
for arg in saved_options: # use saved options to overwrite all BaseOptions args.
|
200 |
+
if arg not in ["results_root", "num_workers", "nms_thd", "debug", # "max_before_nms", "max_after_nms"
|
201 |
+
"max_pred_l", "min_pred_l",
|
202 |
+
"resume", "resume_all", "no_sort_results"]:
|
203 |
+
setattr(opt, arg, saved_options[arg])
|
204 |
+
# opt.no_core_driver = True
|
205 |
+
if opt.eval_results_dir is not None:
|
206 |
+
opt.results_dir = opt.eval_results_dir
|
207 |
+
else:
|
208 |
+
if opt.exp_id is None:
|
209 |
+
raise ValueError("--exp_id is required for at a training option!")
|
210 |
+
|
211 |
+
ctx_str = opt.ctx_mode + "_sub" if any(["sub_ctx" in p for p in opt.v_feat_dirs]) else opt.ctx_mode
|
212 |
+
opt.results_dir = os.path.join(opt.results_root,
|
213 |
+
"-".join([opt.dset_name, ctx_str, opt.exp_id,
|
214 |
+
str(opt.enc_layers) + str(opt.dec_layers) + str(opt.t2v_layers) + str(opt.moment_layers) + str(opt.dummy_layers) + str(opt.sent_layers),
|
215 |
+
'ndum_' + str(opt.num_dummies), 'nprom_' + str(opt.num_prompts) + '_' + str(opt.total_prompts)]))
|
216 |
+
mkdirp(opt.results_dir)
|
217 |
+
save_fns = ['cg_detr/model.py', 'cg_detr/transformer.py']
|
218 |
+
for save_fn in save_fns:
|
219 |
+
shutil.copyfile(save_fn, os.path.join(opt.results_dir, os.path.basename(save_fn)))
|
220 |
+
|
221 |
+
# save a copy of current code
|
222 |
+
code_dir = os.path.dirname(os.path.realpath(__file__))
|
223 |
+
code_zip_filename = os.path.join(opt.results_dir, "code.zip")
|
224 |
+
make_zipfile(code_dir, code_zip_filename,
|
225 |
+
enclosing_dir="code",
|
226 |
+
exclude_dirs_substring="results",
|
227 |
+
exclude_dirs=["results", "debug_results", "__pycache__"],
|
228 |
+
exclude_extensions=[".pyc", ".ipynb", ".swap"], )
|
229 |
+
|
230 |
+
self.display_save(opt)
|
231 |
+
|
232 |
+
opt.ckpt_filepath = os.path.join(opt.results_dir, self.ckpt_filename)
|
233 |
+
opt.train_log_filepath = os.path.join(opt.results_dir, self.train_log_filename)
|
234 |
+
opt.eval_log_filepath = os.path.join(opt.results_dir, self.eval_log_filename)
|
235 |
+
opt.tensorboard_log_dir = os.path.join(opt.results_dir, self.tensorboard_log_dir)
|
236 |
+
opt.device = torch.device("cuda" if opt.device >= 0 else "cpu")
|
237 |
+
opt.pin_memory = not opt.no_pin_memory
|
238 |
+
|
239 |
+
opt.use_tef = "tef" in opt.ctx_mode
|
240 |
+
opt.use_video = "video" in opt.ctx_mode
|
241 |
+
if not opt.use_video:
|
242 |
+
opt.v_feat_dim = 0
|
243 |
+
if opt.use_tef:
|
244 |
+
opt.v_feat_dim += 2
|
245 |
+
|
246 |
+
self.opt = opt
|
247 |
+
return opt
|
248 |
+
|
249 |
+
|
250 |
+
class TestOptions(BaseOptions):
|
251 |
+
"""add additional options for evaluating"""
|
252 |
+
|
253 |
+
def initialize(self):
|
254 |
+
BaseOptions.initialize(self)
|
255 |
+
# also need to specify --eval_split_name
|
256 |
+
self.parser.add_argument("--eval_id", type=str, help="evaluation id")
|
257 |
+
self.parser.add_argument("--eval_results_dir", type=str, default=None,
|
258 |
+
help="dir to save results, if not set, fall back to training results_dir")
|
259 |
+
self.parser.add_argument("--model_dir", type=str,
|
260 |
+
help="dir contains the model file, will be converted to absolute path afterwards")
|
261 |
+
|
third_party/cgdetr/cg_detr/crossattention.py
ADDED
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# DAB-DETR
|
3 |
+
# Copyright (c) 2022 IDEA. All Rights Reserved.
|
4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
5 |
+
# ------------------------------------------------------------------------
|
6 |
+
# Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
|
7 |
+
# Copyright (c) 2021 Microsoft. All Rights Reserved.
|
8 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
9 |
+
# ------------------------------------------------------------------------
|
10 |
+
# Modified from DETR (https://github.com/facebookresearch/detr)
|
11 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
12 |
+
# ------------------------------------------------------------------------
|
13 |
+
# Modified from codes in torch.nn
|
14 |
+
# ------------------------------------------------------------------------
|
15 |
+
|
16 |
+
"""
|
17 |
+
MultiheadAttention that support query, key, and value to have different dimensions.
|
18 |
+
Query, key, and value projections are removed.
|
19 |
+
Mostly copy-paste from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/activation.py#L873
|
20 |
+
and https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L4837
|
21 |
+
"""
|
22 |
+
|
23 |
+
import copy
|
24 |
+
from typing import Optional, List
|
25 |
+
|
26 |
+
import torch
|
27 |
+
import torch.nn.functional as F
|
28 |
+
from torch import nn, Tensor
|
29 |
+
|
30 |
+
import warnings
|
31 |
+
from typing import Tuple, Optional
|
32 |
+
|
33 |
+
import torch
|
34 |
+
from torch import Tensor
|
35 |
+
from torch.nn.modules.linear import Linear
|
36 |
+
from torch.nn.init import xavier_uniform_
|
37 |
+
from torch.nn.init import constant_
|
38 |
+
from torch.nn.init import xavier_normal_
|
39 |
+
from torch.nn.parameter import Parameter
|
40 |
+
from torch.nn.modules.module import Module
|
41 |
+
from torch.nn import functional as F
|
42 |
+
|
43 |
+
import warnings
|
44 |
+
import math
|
45 |
+
|
46 |
+
from torch._C import _infer_size, _add_docstr
|
47 |
+
from torch.nn import _reduction as _Reduction
|
48 |
+
from torch.nn.modules import utils
|
49 |
+
from torch.nn.modules.utils import _single, _pair, _triple, _list_with_default
|
50 |
+
from torch.nn import grad
|
51 |
+
from torch import _VF
|
52 |
+
from torch._jit_internal import boolean_dispatch, List, Optional, _overload, Tuple
|
53 |
+
try:
|
54 |
+
from torch.overrides import has_torch_function, handle_torch_function
|
55 |
+
except:
|
56 |
+
from torch._overrides import has_torch_function, handle_torch_function
|
57 |
+
Tensor = torch.Tensor
|
58 |
+
|
59 |
+
from torch.nn.functional import linear, pad, softmax, dropout
|
60 |
+
|
61 |
+
class MultiheadAttention(Module):
|
62 |
+
r"""Allows the model to jointly attend to information
|
63 |
+
from different representation subspaces.
|
64 |
+
See reference: Attention Is All You Need
|
65 |
+
.. math::
|
66 |
+
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
|
67 |
+
\text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
|
68 |
+
Args:
|
69 |
+
embed_dim: total dimension of the model.
|
70 |
+
num_heads: parallel attention heads.
|
71 |
+
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
|
72 |
+
bias: add bias as module parameter. Default: True.
|
73 |
+
add_bias_kv: add bias to the key and value sequences at dim=0.
|
74 |
+
add_zero_attn: add a new batch of zeros to the key and
|
75 |
+
value sequences at dim=1.
|
76 |
+
kdim: total number of features in key. Default: None.
|
77 |
+
vdim: total number of features in value. Default: None.
|
78 |
+
Note: if kdim and vdim are None, they will be set to embed_dim such that
|
79 |
+
query, key, and value have the same number of features.
|
80 |
+
Examples::
|
81 |
+
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
|
82 |
+
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
|
83 |
+
"""
|
84 |
+
bias_k: Optional[torch.Tensor]
|
85 |
+
bias_v: Optional[torch.Tensor]
|
86 |
+
|
87 |
+
def __init__(self, embed_dim, num_heads, dropout=0., num_dummies=3, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
|
88 |
+
super(MultiheadAttention, self).__init__()
|
89 |
+
self.num_dummies = num_dummies
|
90 |
+
self.embed_dim = embed_dim
|
91 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
92 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
93 |
+
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
94 |
+
|
95 |
+
self.num_heads = num_heads
|
96 |
+
self.dropout = dropout
|
97 |
+
self.head_dim = embed_dim // num_heads
|
98 |
+
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
99 |
+
|
100 |
+
vdim = vdim if vdim is not None else embed_dim
|
101 |
+
self.out_proj = Linear(vdim , vdim)
|
102 |
+
|
103 |
+
self.in_proj_bias = None
|
104 |
+
self.in_proj_weight = None
|
105 |
+
self.bias_k = self.bias_v = None
|
106 |
+
self.q_proj_weight = None
|
107 |
+
self.k_proj_weight = None
|
108 |
+
self.v_proj_weight = None
|
109 |
+
|
110 |
+
self.add_zero_attn = add_zero_attn
|
111 |
+
|
112 |
+
self._reset_parameters()
|
113 |
+
|
114 |
+
def _reset_parameters(self):
|
115 |
+
constant_(self.out_proj.bias, 0.)
|
116 |
+
|
117 |
+
def __setstate__(self, state):
|
118 |
+
# Support loading old MultiheadAttention checkpoints generated by v1.1.0
|
119 |
+
if '_qkv_same_embed_dim' not in state:
|
120 |
+
state['_qkv_same_embed_dim'] = True
|
121 |
+
|
122 |
+
super(MultiheadAttention, self).__setstate__(state)
|
123 |
+
|
124 |
+
def forward(self, query, key, value, key_padding_mask=None,
|
125 |
+
need_weights=True, attn_mask=None, dummy=True):
|
126 |
+
# type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
|
127 |
+
r"""
|
128 |
+
Args:
|
129 |
+
query, key, value: map a query and a set of key-value pairs to an output.
|
130 |
+
See "Attention Is All You Need" for more details.
|
131 |
+
key_padding_mask: if provided, specified padding elements in the key will
|
132 |
+
be ignored by the attention. When given a binary mask and a value is True,
|
133 |
+
the corresponding value on the attention layer will be ignored. When given
|
134 |
+
a byte mask and a value is non-zero, the corresponding value on the attention
|
135 |
+
layer will be ignored
|
136 |
+
need_weights: output attn_output_weights.
|
137 |
+
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
138 |
+
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
139 |
+
Shape:
|
140 |
+
- Inputs:
|
141 |
+
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
142 |
+
the embedding dimension.
|
143 |
+
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
144 |
+
the embedding dimension.
|
145 |
+
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
146 |
+
the embedding dimension.
|
147 |
+
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
148 |
+
If a ByteTensor is provided, the non-zero positions will be ignored while the position
|
149 |
+
with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
|
150 |
+
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
151 |
+
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
152 |
+
3D mask :math:`(N*\text{num_heads}, L, S)` where N is the batch size, L is the target sequence length,
|
153 |
+
S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
|
154 |
+
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
|
155 |
+
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
|
156 |
+
is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
157 |
+
is provided, it will be added to the attention weight.
|
158 |
+
- Outputs:
|
159 |
+
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
160 |
+
E is the embedding dimension.
|
161 |
+
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
|
162 |
+
L is the target sequence length, S is the source sequence length.
|
163 |
+
"""
|
164 |
+
if not self._qkv_same_embed_dim:
|
165 |
+
return multi_head_attention_forward(
|
166 |
+
query, key, value, self.embed_dim, self.num_heads,
|
167 |
+
self.in_proj_weight, self.in_proj_bias,
|
168 |
+
self.bias_k, self.bias_v, self.add_zero_attn,
|
169 |
+
self.dropout, self.out_proj.weight, self.out_proj.bias,
|
170 |
+
training=self.training,
|
171 |
+
key_padding_mask=key_padding_mask, need_weights=need_weights,
|
172 |
+
attn_mask=attn_mask, use_separate_proj_weight=True,
|
173 |
+
q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
|
174 |
+
v_proj_weight=self.v_proj_weight, out_dim=self.vdim, num_dummies=self.num_dummies, dummy=dummy)
|
175 |
+
else:
|
176 |
+
return multi_head_attention_forward(
|
177 |
+
query, key, value, self.embed_dim, self.num_heads,
|
178 |
+
self.in_proj_weight, self.in_proj_bias,
|
179 |
+
self.bias_k, self.bias_v, self.add_zero_attn,
|
180 |
+
self.dropout, self.out_proj.weight, self.out_proj.bias,
|
181 |
+
training=self.training,
|
182 |
+
key_padding_mask=key_padding_mask, need_weights=need_weights,
|
183 |
+
attn_mask=attn_mask, out_dim=self.vdim, num_dummies=self.num_dummies, dummy=dummy)
|
184 |
+
|
185 |
+
|
186 |
+
def multi_head_attention_forward(query: Tensor,
|
187 |
+
key: Tensor,
|
188 |
+
value: Tensor,
|
189 |
+
embed_dim_to_check: int,
|
190 |
+
num_heads: int,
|
191 |
+
in_proj_weight: Tensor,
|
192 |
+
in_proj_bias: Tensor,
|
193 |
+
bias_k: Optional[Tensor],
|
194 |
+
bias_v: Optional[Tensor],
|
195 |
+
add_zero_attn: bool,
|
196 |
+
dropout_p: float,
|
197 |
+
out_proj_weight: Tensor,
|
198 |
+
out_proj_bias: Tensor,
|
199 |
+
training: bool = True,
|
200 |
+
key_padding_mask: Optional[Tensor] = None,
|
201 |
+
need_weights: bool = True,
|
202 |
+
attn_mask: Optional[Tensor] = None,
|
203 |
+
use_separate_proj_weight: bool = False,
|
204 |
+
q_proj_weight: Optional[Tensor] = None,
|
205 |
+
k_proj_weight: Optional[Tensor] = None,
|
206 |
+
v_proj_weight: Optional[Tensor] = None,
|
207 |
+
static_k: Optional[Tensor] = None,
|
208 |
+
static_v: Optional[Tensor] = None,
|
209 |
+
out_dim: Optional[Tensor] = None,
|
210 |
+
num_dummies=3,
|
211 |
+
dummy=True,
|
212 |
+
) -> Tuple[Tensor, Optional[Tensor]]:
|
213 |
+
r"""
|
214 |
+
Args:
|
215 |
+
query, key, value: map a query and a set of key-value pairs to an output.
|
216 |
+
See "Attention Is All You Need" for more details.
|
217 |
+
embed_dim_to_check: total dimension of the model.
|
218 |
+
num_heads: parallel attention heads.
|
219 |
+
in_proj_weight, in_proj_bias: input projection weight and bias.
|
220 |
+
bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
|
221 |
+
add_zero_attn: add a new batch of zeros to the key and
|
222 |
+
value sequences at dim=1.
|
223 |
+
dropout_p: probability of an element to be zeroed.
|
224 |
+
out_proj_weight, out_proj_bias: the output projection weight and bias.
|
225 |
+
training: apply dropout if is ``True``.
|
226 |
+
key_padding_mask: if provided, specified padding elements in the key will
|
227 |
+
be ignored by the attention. This is an binary mask. When the value is True,
|
228 |
+
the corresponding value on the attention layer will be filled with -inf.
|
229 |
+
need_weights: output attn_output_weights.
|
230 |
+
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
231 |
+
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
232 |
+
use_separate_proj_weight: the function accept the proj. weights for query, key,
|
233 |
+
and value in different forms. If false, in_proj_weight will be used, which is
|
234 |
+
a combination of q_proj_weight, k_proj_weight, v_proj_weight.
|
235 |
+
q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
|
236 |
+
static_k, static_v: static key and value used for attention operators.
|
237 |
+
Shape:
|
238 |
+
Inputs:
|
239 |
+
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
240 |
+
the embedding dimension.
|
241 |
+
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
242 |
+
the embedding dimension.
|
243 |
+
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
244 |
+
the embedding dimension.
|
245 |
+
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
246 |
+
If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
|
247 |
+
will be unchanged. If a BoolTensor is provided, the positions with the
|
248 |
+
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
249 |
+
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
250 |
+
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
|
251 |
+
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
|
252 |
+
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
|
253 |
+
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
|
254 |
+
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
255 |
+
is provided, it will be added to the attention weight.
|
256 |
+
- static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
|
257 |
+
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
|
258 |
+
- static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
|
259 |
+
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
|
260 |
+
Outputs:
|
261 |
+
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
262 |
+
E is the embedding dimension.
|
263 |
+
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
|
264 |
+
L is the target sequence length, S is the source sequence length.
|
265 |
+
"""
|
266 |
+
if not torch.jit.is_scripting():
|
267 |
+
tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v,
|
268 |
+
out_proj_weight, out_proj_bias)
|
269 |
+
if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
|
270 |
+
return handle_torch_function(
|
271 |
+
multi_head_attention_forward, tens_ops, query, key, value,
|
272 |
+
embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias,
|
273 |
+
bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight,
|
274 |
+
out_proj_bias, training=training, key_padding_mask=key_padding_mask,
|
275 |
+
need_weights=need_weights, attn_mask=attn_mask,
|
276 |
+
use_separate_proj_weight=use_separate_proj_weight,
|
277 |
+
q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight,
|
278 |
+
v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v)
|
279 |
+
tgt_len, bsz, embed_dim = query.size()
|
280 |
+
assert embed_dim == embed_dim_to_check
|
281 |
+
# allow MHA to have different sizes for the feature dimension
|
282 |
+
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
|
283 |
+
|
284 |
+
head_dim = embed_dim // num_heads
|
285 |
+
v_head_dim = out_dim // num_heads
|
286 |
+
assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
|
287 |
+
scaling = float(head_dim) ** -0.5
|
288 |
+
|
289 |
+
q = query * scaling
|
290 |
+
k = key
|
291 |
+
v = value
|
292 |
+
|
293 |
+
if attn_mask is not None:
|
294 |
+
assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
|
295 |
+
attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
|
296 |
+
'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
|
297 |
+
if attn_mask.dtype == torch.uint8:
|
298 |
+
warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
|
299 |
+
attn_mask = attn_mask.to(torch.bool)
|
300 |
+
|
301 |
+
if attn_mask.dim() == 2:
|
302 |
+
attn_mask = attn_mask.unsqueeze(0)
|
303 |
+
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
|
304 |
+
raise RuntimeError('The size of the 2D attn_mask is not correct.')
|
305 |
+
elif attn_mask.dim() == 3:
|
306 |
+
if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
|
307 |
+
raise RuntimeError('The size of the 3D attn_mask is not correct.')
|
308 |
+
else:
|
309 |
+
raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
|
310 |
+
# attn_mask's dim is 3 now.
|
311 |
+
|
312 |
+
# convert ByteTensor key_padding_mask to bool
|
313 |
+
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
|
314 |
+
warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
|
315 |
+
key_padding_mask = key_padding_mask.to(torch.bool)
|
316 |
+
|
317 |
+
if bias_k is not None and bias_v is not None:
|
318 |
+
if static_k is None and static_v is None:
|
319 |
+
k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
|
320 |
+
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
|
321 |
+
if attn_mask is not None:
|
322 |
+
attn_mask = pad(attn_mask, (0, 1))
|
323 |
+
if key_padding_mask is not None:
|
324 |
+
key_padding_mask = pad(key_padding_mask, (0, 1))
|
325 |
+
else:
|
326 |
+
assert static_k is None, "bias cannot be added to static key."
|
327 |
+
assert static_v is None, "bias cannot be added to static value."
|
328 |
+
else:
|
329 |
+
assert bias_k is None
|
330 |
+
assert bias_v is None
|
331 |
+
|
332 |
+
q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
|
333 |
+
if k is not None:
|
334 |
+
k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
|
335 |
+
if v is not None:
|
336 |
+
v = v.contiguous().view(-1, bsz * num_heads, v_head_dim).transpose(0, 1)
|
337 |
+
|
338 |
+
if static_k is not None:
|
339 |
+
assert static_k.size(0) == bsz * num_heads
|
340 |
+
assert static_k.size(2) == head_dim
|
341 |
+
k = static_k
|
342 |
+
|
343 |
+
if static_v is not None:
|
344 |
+
assert static_v.size(0) == bsz * num_heads
|
345 |
+
assert static_v.size(2) == v_head_dim
|
346 |
+
v = static_v
|
347 |
+
|
348 |
+
src_len = k.size(1)
|
349 |
+
|
350 |
+
if key_padding_mask is not None:
|
351 |
+
assert key_padding_mask.size(0) == bsz
|
352 |
+
assert key_padding_mask.size(1) == src_len
|
353 |
+
|
354 |
+
if add_zero_attn:
|
355 |
+
src_len += 1
|
356 |
+
k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
|
357 |
+
v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
|
358 |
+
if attn_mask is not None:
|
359 |
+
attn_mask = pad(attn_mask, (0, 1))
|
360 |
+
if key_padding_mask is not None:
|
361 |
+
key_padding_mask = pad(key_padding_mask, (0, 1))
|
362 |
+
|
363 |
+
attn_output_weights = torch.bmm(q, k.transpose(1, 2))
|
364 |
+
assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
|
365 |
+
|
366 |
+
if attn_mask is not None:
|
367 |
+
if attn_mask.dtype == torch.bool:
|
368 |
+
attn_output_weights.masked_fill_(attn_mask, float('-inf'))
|
369 |
+
else:
|
370 |
+
attn_output_weights += attn_mask
|
371 |
+
|
372 |
+
|
373 |
+
if key_padding_mask is not None:
|
374 |
+
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
|
375 |
+
attn_output_weights = attn_output_weights.masked_fill(
|
376 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2),
|
377 |
+
float('-inf'),
|
378 |
+
)
|
379 |
+
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
|
380 |
+
|
381 |
+
attn_output_weights = softmax(attn_output_weights, dim=-1)
|
382 |
+
attn_output_weights_d = dropout(attn_output_weights, p=dropout_p, training=training)
|
383 |
+
if dummy:
|
384 |
+
attn_output = torch.bmm(attn_output_weights_d[:, :, num_dummies:], v[:, num_dummies:,:])
|
385 |
+
else:
|
386 |
+
attn_output = torch.bmm(attn_output_weights_d, v)
|
387 |
+
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, v_head_dim]
|
388 |
+
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, out_dim)
|
389 |
+
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
390 |
+
|
391 |
+
if need_weights:
|
392 |
+
# average attention weights over heads
|
393 |
+
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
|
394 |
+
return attn_output, attn_output_weights.sum(dim=1) / num_heads
|
395 |
+
else:
|
396 |
+
return attn_output, None
|
third_party/cgdetr/cg_detr/inference.py
ADDED
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pprint
|
2 |
+
from tqdm import tqdm, trange
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
from collections import OrderedDict, defaultdict
|
6 |
+
from utils.basic_utils import AverageMeter
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torch.backends.cudnn as cudnn
|
11 |
+
from torch.utils.data import DataLoader
|
12 |
+
|
13 |
+
from cg_detr.config import TestOptions
|
14 |
+
from cg_detr.model import build_model
|
15 |
+
from cg_detr.span_utils import span_cxw_to_xx
|
16 |
+
from cg_detr.start_end_dataset import StartEndDataset, start_end_collate, prepare_batch_inputs
|
17 |
+
from cg_detr.postprocessing_cg_detr import PostProcessorDETR
|
18 |
+
from standalone_eval.eval import eval_submission
|
19 |
+
from utils.basic_utils import save_jsonl, save_json
|
20 |
+
from utils.temporal_nms import temporal_nms
|
21 |
+
|
22 |
+
import logging
|
23 |
+
|
24 |
+
logger = logging.getLogger(__name__)
|
25 |
+
logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
|
26 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
27 |
+
level=logging.INFO)
|
28 |
+
|
29 |
+
|
30 |
+
def post_processing_mr_nms(mr_res, nms_thd, max_before_nms, max_after_nms):
|
31 |
+
mr_res_after_nms = []
|
32 |
+
for e in mr_res:
|
33 |
+
e["pred_relevant_windows"] = temporal_nms(
|
34 |
+
e["pred_relevant_windows"][:max_before_nms],
|
35 |
+
nms_thd=nms_thd,
|
36 |
+
max_after_nms=max_after_nms
|
37 |
+
)
|
38 |
+
mr_res_after_nms.append(e)
|
39 |
+
return mr_res_after_nms
|
40 |
+
|
41 |
+
|
42 |
+
def eval_epoch_post_processing(submission, opt, gt_data, save_submission_filename):
|
43 |
+
# IOU_THDS = (0.5, 0.7)
|
44 |
+
logger.info("Saving/Evaluating before nms results")
|
45 |
+
submission_path = os.path.join(opt.results_dir, save_submission_filename)
|
46 |
+
save_jsonl(submission, submission_path)
|
47 |
+
|
48 |
+
if opt.eval_split_name in ["val"]: # since test_public has no GT
|
49 |
+
metrics = eval_submission(
|
50 |
+
submission, gt_data,
|
51 |
+
verbose=opt.debug, match_number=not opt.debug
|
52 |
+
)
|
53 |
+
save_metrics_path = submission_path.replace(".jsonl", "_metrics.json")
|
54 |
+
save_json(metrics, save_metrics_path, save_pretty=True, sort_keys=False)
|
55 |
+
latest_file_paths = [submission_path, save_metrics_path]
|
56 |
+
else:
|
57 |
+
metrics = None
|
58 |
+
latest_file_paths = [submission_path, ]
|
59 |
+
|
60 |
+
if opt.nms_thd != -1:
|
61 |
+
logger.info("[MR] Performing nms with nms_thd {}".format(opt.nms_thd))
|
62 |
+
submission_after_nms = post_processing_mr_nms(
|
63 |
+
submission, nms_thd=opt.nms_thd,
|
64 |
+
max_before_nms=opt.max_before_nms, max_after_nms=opt.max_after_nms
|
65 |
+
)
|
66 |
+
|
67 |
+
logger.info("Saving/Evaluating nms results")
|
68 |
+
submission_nms_path = submission_path.replace(".jsonl", "_nms_thd_{}.jsonl".format(opt.nms_thd))
|
69 |
+
save_jsonl(submission_after_nms, submission_nms_path)
|
70 |
+
if opt.eval_split_name == "val":
|
71 |
+
metrics_nms = eval_submission(
|
72 |
+
submission_after_nms, gt_data,
|
73 |
+
verbose=opt.debug, match_number=not opt.debug
|
74 |
+
)
|
75 |
+
save_metrics_nms_path = submission_nms_path.replace(".jsonl", "_metrics.json")
|
76 |
+
save_json(metrics_nms, save_metrics_nms_path, save_pretty=True, sort_keys=False)
|
77 |
+
latest_file_paths += [submission_nms_path, save_metrics_nms_path]
|
78 |
+
else:
|
79 |
+
metrics_nms = None
|
80 |
+
latest_file_paths = [submission_nms_path, ]
|
81 |
+
else:
|
82 |
+
metrics_nms = None
|
83 |
+
return metrics, metrics_nms, latest_file_paths
|
84 |
+
|
85 |
+
|
86 |
+
# for HL
|
87 |
+
@torch.no_grad()
|
88 |
+
def compute_hl_results(model, eval_loader, opt, epoch_i=None, criterion=None, tb_writer=None):
|
89 |
+
model.eval()
|
90 |
+
if criterion:
|
91 |
+
assert eval_loader.dataset.load_labels
|
92 |
+
criterion.eval()
|
93 |
+
|
94 |
+
loss_meters = defaultdict(AverageMeter)
|
95 |
+
write_tb = tb_writer is not None and epoch_i is not None
|
96 |
+
|
97 |
+
mr_res = []
|
98 |
+
|
99 |
+
topk = 5 # top-5 map
|
100 |
+
|
101 |
+
video_ap_collected = []
|
102 |
+
for batch in tqdm(eval_loader, desc="compute st ed scores"):
|
103 |
+
query_meta = batch[0]
|
104 |
+
|
105 |
+
model_inputs, targets = prepare_batch_inputs(batch[1], opt.device, non_blocking=opt.pin_memory)
|
106 |
+
|
107 |
+
outputs = model(**model_inputs)
|
108 |
+
|
109 |
+
# loss meters
|
110 |
+
# if criterion:
|
111 |
+
# loss_dict = criterion(outputs, targets)
|
112 |
+
# weight_dict = criterion.weight_dict
|
113 |
+
# print(loss_dict)
|
114 |
+
# print(weight_dict)
|
115 |
+
# print('#######')
|
116 |
+
# {'loss_saliency': tensor(18.1374, device='cuda:0')}
|
117 |
+
# {'loss_span': 10, 'loss_giou': 1, 'loss_label': 4, 'loss_saliency': 1.0, 'loss_ms_align': 1.0,
|
118 |
+
# 'loss_distill': 1.0, 'loss_span_0': 10, 'loss_giou_0': 1, 'loss_label_0': 4, 'loss_ms_align_0': 1.0,
|
119 |
+
# 'loss_distill_0': 1.0}
|
120 |
+
# losses=0.
|
121 |
+
# print(loss_dict.keys(), weight_dict.keys())
|
122 |
+
# losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
|
123 |
+
# loss_dict["loss_overall"] = float(losses) # for logging only
|
124 |
+
# print(loss_dict.items())
|
125 |
+
#
|
126 |
+
# print(weight_dict.items())
|
127 |
+
# for k, v in loss_dict.items():
|
128 |
+
# loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))
|
129 |
+
|
130 |
+
|
131 |
+
preds = outputs['saliency_scores'].clone().detach()
|
132 |
+
|
133 |
+
for meta, pred in zip(query_meta, preds):
|
134 |
+
pred = pred
|
135 |
+
label = meta['label'] # raw label
|
136 |
+
|
137 |
+
video_ap = []
|
138 |
+
# Follow the UMT code "https://github.com/TencentARC/UMT/blob/main/datasets/tvsum.py"
|
139 |
+
|
140 |
+
if opt.dset_name in ["tvsum"]:
|
141 |
+
for i in range(20):
|
142 |
+
pred=pred.cpu()
|
143 |
+
cur_pred = pred[:len(label)]
|
144 |
+
inds = torch.argsort(cur_pred, descending=True, dim=-1)
|
145 |
+
|
146 |
+
# video_id = self.get_video_id(idx)
|
147 |
+
cur_label = torch.Tensor(label)[:, i]
|
148 |
+
cur_label = torch.where(cur_label > cur_label.median(), 1.0, .0)
|
149 |
+
|
150 |
+
cur_label = cur_label[inds].tolist()[:topk]
|
151 |
+
|
152 |
+
# if (num_gt := sum(cur_label)) == 0:
|
153 |
+
num_gt = sum(cur_label)
|
154 |
+
if num_gt == 0:
|
155 |
+
video_ap.append(0)
|
156 |
+
continue
|
157 |
+
|
158 |
+
hits = ap = rec = 0
|
159 |
+
prc = 1
|
160 |
+
|
161 |
+
for j, gt in enumerate(cur_label):
|
162 |
+
hits += gt
|
163 |
+
|
164 |
+
_rec = hits / num_gt
|
165 |
+
_prc = hits / (j + 1)
|
166 |
+
|
167 |
+
ap += (_rec - rec) * (prc + _prc) / 2
|
168 |
+
rec, prc = _rec, _prc
|
169 |
+
|
170 |
+
video_ap.append(ap)
|
171 |
+
|
172 |
+
elif opt.dset_name in ["youtube_uni"]:
|
173 |
+
cur_pred = pred[:len(label)]
|
174 |
+
# if opt.dset_name == "tvsum_sfc":
|
175 |
+
cur_pred = cur_pred.cpu()
|
176 |
+
inds = torch.argsort(cur_pred, descending=True, dim=-1)
|
177 |
+
|
178 |
+
|
179 |
+
cur_label = torch.Tensor(label).squeeze()[inds].tolist()
|
180 |
+
|
181 |
+
num_gt = sum(cur_label)
|
182 |
+
if num_gt == 0:
|
183 |
+
video_ap.append(0)
|
184 |
+
continue
|
185 |
+
|
186 |
+
hits = ap = rec = 0
|
187 |
+
prc = 1
|
188 |
+
|
189 |
+
for j, gt in enumerate(cur_label):
|
190 |
+
hits += gt
|
191 |
+
|
192 |
+
_rec = hits / num_gt
|
193 |
+
_prc = hits / (j + 1)
|
194 |
+
|
195 |
+
ap += (_rec - rec) * (prc + _prc) / 2
|
196 |
+
rec, prc = _rec, _prc
|
197 |
+
|
198 |
+
video_ap.append(float(ap))
|
199 |
+
else:
|
200 |
+
print("No such dataset")
|
201 |
+
exit(-1)
|
202 |
+
|
203 |
+
video_ap_collected.append(video_ap)
|
204 |
+
|
205 |
+
mean_ap = np.mean(video_ap_collected)
|
206 |
+
submmission = dict(mAP=round(mean_ap, 5))
|
207 |
+
|
208 |
+
|
209 |
+
# tensorboard writer
|
210 |
+
if write_tb and criterion:
|
211 |
+
for k, v in loss_meters.items():
|
212 |
+
tb_writer.add_scalar("Eval/{}".format(k), v.avg, epoch_i + 1)
|
213 |
+
|
214 |
+
return submmission, loss_meters
|
215 |
+
|
216 |
+
|
217 |
+
|
218 |
+
@torch.no_grad()
|
219 |
+
def compute_mr_results(model, eval_loader, opt, epoch_i=None, criterion=None, tb_writer=None):
|
220 |
+
model.eval()
|
221 |
+
if criterion:
|
222 |
+
assert eval_loader.dataset.load_labels
|
223 |
+
criterion.eval()
|
224 |
+
|
225 |
+
loss_meters = defaultdict(AverageMeter)
|
226 |
+
write_tb = tb_writer is not None and epoch_i is not None
|
227 |
+
|
228 |
+
mr_res = []
|
229 |
+
for batch in tqdm(eval_loader, desc="compute st ed scores"):
|
230 |
+
query_meta = batch[0]
|
231 |
+
|
232 |
+
model_inputs, targets = prepare_batch_inputs(batch[1], opt.device, non_blocking=opt.pin_memory)
|
233 |
+
|
234 |
+
outputs = model(**model_inputs)
|
235 |
+
prob = F.softmax(outputs["pred_logits"], -1) # (batch_size, #queries, #classes=2)
|
236 |
+
if opt.span_loss_type == "l1":
|
237 |
+
scores = prob[..., 0] # * (batch_size, #queries) foreground label is 0, we directly take it
|
238 |
+
pred_spans = outputs["pred_spans"] # (bsz, #queries, 2)
|
239 |
+
_saliency_scores = outputs["saliency_scores"].half() # (bsz, L)
|
240 |
+
saliency_scores = []
|
241 |
+
valid_vid_lengths = model_inputs["src_vid_mask"].sum(1).cpu().tolist()
|
242 |
+
for j in range(len(valid_vid_lengths)):
|
243 |
+
saliency_scores.append(_saliency_scores[j, :int(valid_vid_lengths[j])].tolist())
|
244 |
+
else:
|
245 |
+
bsz, n_queries = outputs["pred_spans"].shape[:2] # # (bsz, #queries, max_v_l *2)
|
246 |
+
pred_spans_logits = outputs["pred_spans"].view(bsz, n_queries, 2, opt.max_v_l)
|
247 |
+
pred_span_scores, pred_spans = F.softmax(pred_spans_logits, dim=-1).max(-1) # 2 * (bsz, #queries, 2)
|
248 |
+
scores = torch.prod(pred_span_scores, 2) # (bsz, #queries)
|
249 |
+
pred_spans[:, 1] += 1
|
250 |
+
pred_spans *= opt.clip_length
|
251 |
+
|
252 |
+
# compose predictions
|
253 |
+
for idx, (meta, spans, score) in enumerate(zip(query_meta, pred_spans.cpu(), scores.cpu())):
|
254 |
+
if opt.span_loss_type == "l1":
|
255 |
+
spans = span_cxw_to_xx(spans) * meta["duration"]
|
256 |
+
spans = torch.clamp(spans, 0, meta["duration"])
|
257 |
+
# # (#queries, 3), [st(float), ed(float), score(float)]
|
258 |
+
cur_ranked_preds = torch.cat([spans, score[:, None]], dim=1).tolist()
|
259 |
+
if not opt.no_sort_results:
|
260 |
+
cur_ranked_preds = sorted(cur_ranked_preds, key=lambda x: x[2], reverse=True)
|
261 |
+
cur_ranked_preds = [[float(f"{e:.4f}") for e in row] for row in cur_ranked_preds]
|
262 |
+
cur_query_pred = dict(
|
263 |
+
qid=meta["qid"],
|
264 |
+
query=meta["query"],
|
265 |
+
vid=meta["vid"],
|
266 |
+
pred_relevant_windows=cur_ranked_preds,
|
267 |
+
pred_saliency_scores=saliency_scores[idx]
|
268 |
+
)
|
269 |
+
mr_res.append(cur_query_pred)
|
270 |
+
|
271 |
+
if criterion:
|
272 |
+
loss_dict = criterion(outputs, targets)
|
273 |
+
weight_dict = criterion.weight_dict
|
274 |
+
losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
|
275 |
+
loss_dict["loss_overall"] = float(losses) # for logging only
|
276 |
+
for k, v in loss_dict.items():
|
277 |
+
loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))
|
278 |
+
|
279 |
+
if opt.debug:
|
280 |
+
break
|
281 |
+
|
282 |
+
if write_tb and criterion:
|
283 |
+
for k, v in loss_meters.items():
|
284 |
+
tb_writer.add_scalar("Eval/{}".format(k), v.avg, epoch_i + 1)
|
285 |
+
|
286 |
+
if opt.dset_name in ['hl']:
|
287 |
+
post_processor = PostProcessorDETR(
|
288 |
+
clip_length=opt.clip_length, min_ts_val=0, max_ts_val=150,
|
289 |
+
min_w_l=2, max_w_l=150, move_window_method="left",
|
290 |
+
process_func_names=("clip_ts", "round_multiple")
|
291 |
+
)
|
292 |
+
elif opt.dset_name in ['charadesSTA']:
|
293 |
+
if opt.v_feat_dim == 4096: # vgg
|
294 |
+
post_processor = PostProcessorDETR(
|
295 |
+
clip_length=opt.clip_length, min_ts_val=0, max_ts_val=360,
|
296 |
+
min_w_l=12, max_w_l=360, move_window_method="left",
|
297 |
+
process_func_names=("clip_ts", "round_multiple")
|
298 |
+
)
|
299 |
+
else:
|
300 |
+
post_processor = PostProcessorDETR(
|
301 |
+
clip_length=opt.clip_length, min_ts_val=0, max_ts_val=150,
|
302 |
+
min_w_l=2, max_w_l=60, move_window_method="left",
|
303 |
+
process_func_names=("clip_ts", "round_multiple")
|
304 |
+
)
|
305 |
+
else:
|
306 |
+
post_processor = PostProcessorDETR(
|
307 |
+
clip_length=opt.clip_length, min_ts_val=0, max_ts_val=50000,
|
308 |
+
min_w_l=0, max_w_l=50000, move_window_method="left",
|
309 |
+
process_func_names=(["round_multiple"])
|
310 |
+
)
|
311 |
+
|
312 |
+
mr_res = post_processor(mr_res)
|
313 |
+
return mr_res, loss_meters
|
314 |
+
|
315 |
+
|
316 |
+
def get_eval_res(model, eval_loader, opt, epoch_i, criterion, tb_writer):
|
317 |
+
"""compute and save query and video proposal embeddings"""
|
318 |
+
eval_res, eval_loss_meters = compute_mr_results(model, eval_loader, opt, epoch_i, criterion, tb_writer) # list(dict)
|
319 |
+
return eval_res, eval_loss_meters
|
320 |
+
|
321 |
+
|
322 |
+
def eval_epoch(model, eval_dataset, opt, save_submission_filename, epoch_i=None, criterion=None, tb_writer=None):
|
323 |
+
logger.info("Generate submissions")
|
324 |
+
model.eval()
|
325 |
+
if criterion is not None and eval_dataset.load_labels:
|
326 |
+
criterion.eval()
|
327 |
+
else:
|
328 |
+
criterion = None
|
329 |
+
|
330 |
+
if opt.dset_name == 'tacos':
|
331 |
+
shuffle = True
|
332 |
+
else:
|
333 |
+
shuffle = False
|
334 |
+
|
335 |
+
eval_loader = DataLoader(
|
336 |
+
eval_dataset,
|
337 |
+
collate_fn=start_end_collate,
|
338 |
+
batch_size=opt.eval_bsz,
|
339 |
+
num_workers=opt.num_workers,
|
340 |
+
shuffle=shuffle,
|
341 |
+
pin_memory=opt.pin_memory
|
342 |
+
)
|
343 |
+
|
344 |
+
|
345 |
+
# tvsum
|
346 |
+
if opt.dset_name in ['tvsum', 'youtube_uni']:
|
347 |
+
metrics, eval_loss_meters = compute_hl_results(model, eval_loader, opt, epoch_i, criterion, tb_writer)
|
348 |
+
|
349 |
+
# to match original save format
|
350 |
+
submission = [
|
351 |
+
{"brief": metrics}
|
352 |
+
]
|
353 |
+
submission_path = os.path.join(opt.results_dir, "latest_metric.jsonl")
|
354 |
+
save_jsonl(submission, submission_path)
|
355 |
+
|
356 |
+
return submission[0], submission[0], eval_loss_meters, [submission_path]
|
357 |
+
|
358 |
+
else:
|
359 |
+
submission, eval_loss_meters = get_eval_res(model, eval_loader, opt, epoch_i, criterion, tb_writer)
|
360 |
+
|
361 |
+
if opt.dset_name in ['charadesSTA', 'tacos', 'nlq']:
|
362 |
+
new_submission = []
|
363 |
+
for s in submission:
|
364 |
+
s.pop('pred_saliency_scores', None)
|
365 |
+
new_submission.append(s)
|
366 |
+
submission = new_submission
|
367 |
+
|
368 |
+
if opt.no_sort_results:
|
369 |
+
save_submission_filename = save_submission_filename.replace(".jsonl", "_unsorted.jsonl")
|
370 |
+
metrics, metrics_nms, latest_file_paths = eval_epoch_post_processing(
|
371 |
+
submission, opt, eval_dataset.data, save_submission_filename)
|
372 |
+
return metrics, metrics_nms, eval_loss_meters, latest_file_paths
|
373 |
+
|
374 |
+
|
375 |
+
def setup_model(opt):
|
376 |
+
"""setup model/optimizer/scheduler and load checkpoints when needed"""
|
377 |
+
logger.info("setup model/optimizer/scheduler")
|
378 |
+
model, criterion = build_model(opt)
|
379 |
+
if opt.device.type == "cuda":
|
380 |
+
logger.info("CUDA enabled.")
|
381 |
+
model.to(opt.device)
|
382 |
+
criterion.to(opt.device)
|
383 |
+
|
384 |
+
param_dicts = [{"params": [p for n, p in model.named_parameters() if p.requires_grad]}]
|
385 |
+
optimizer = torch.optim.AdamW(param_dicts, lr=opt.lr, weight_decay=opt.wd)
|
386 |
+
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, opt.lr_drop)
|
387 |
+
|
388 |
+
if opt.resume is not None:
|
389 |
+
logger.info(f"Load checkpoint from {opt.resume}")
|
390 |
+
checkpoint = torch.load(opt.resume, map_location="cpu")
|
391 |
+
from collections import OrderedDict
|
392 |
+
new_state_dict = OrderedDict()
|
393 |
+
if 'pt' in opt.resume[:-4]:
|
394 |
+
if 'asr' in opt.resume[:25]:
|
395 |
+
model.load_state_dict(checkpoint["model"])
|
396 |
+
else:
|
397 |
+
for k, v in checkpoint["model"].items():
|
398 |
+
name = k[7:] # remove `module.`
|
399 |
+
new_state_dict[name] = v
|
400 |
+
# model.load_state_dict(checkpoint["model"])
|
401 |
+
model.load_state_dict(new_state_dict)
|
402 |
+
else:
|
403 |
+
model.load_state_dict(checkpoint["model"])
|
404 |
+
if opt.resume_all:
|
405 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
406 |
+
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
|
407 |
+
opt.start_epoch = checkpoint['epoch'] + 1
|
408 |
+
logger.info(f"Loaded model saved at epoch {checkpoint['epoch']} from checkpoint: {opt.resume}")
|
409 |
+
else:
|
410 |
+
logger.warning("If you intend to evaluate the model, please specify --resume with ckpt path")
|
411 |
+
|
412 |
+
return model, criterion, optimizer, lr_scheduler
|
413 |
+
|
414 |
+
|
415 |
+
def start_inference(train_opt=None, split=None, splitfile=None):
|
416 |
+
if train_opt is not None:
|
417 |
+
opt = TestOptions().parse(train_opt.a_feat_dir)
|
418 |
+
else:
|
419 |
+
opt = TestOptions().parse()
|
420 |
+
if split is not None:
|
421 |
+
opt.eval_split_name = split
|
422 |
+
if splitfile is not None:
|
423 |
+
opt.eval_path = splitfile
|
424 |
+
|
425 |
+
print(opt.eval_split_name)
|
426 |
+
print(opt.eval_path)
|
427 |
+
logger.info("Setup config, data and model...")
|
428 |
+
|
429 |
+
|
430 |
+
cudnn.benchmark = True
|
431 |
+
cudnn.deterministic = False
|
432 |
+
|
433 |
+
assert opt.eval_path is not None
|
434 |
+
if opt.eval_split_name == 'val':
|
435 |
+
loadlabel = True
|
436 |
+
else:
|
437 |
+
loadlabel = False
|
438 |
+
|
439 |
+
eval_dataset = StartEndDataset(
|
440 |
+
dset_name=opt.dset_name,
|
441 |
+
data_path=opt.eval_path,
|
442 |
+
v_feat_dirs=opt.v_feat_dirs,
|
443 |
+
q_feat_dir=opt.t_feat_dir,
|
444 |
+
q_feat_type="last_hidden_state",
|
445 |
+
max_q_l=opt.max_q_l,
|
446 |
+
max_v_l=opt.max_v_l,
|
447 |
+
ctx_mode=opt.ctx_mode,
|
448 |
+
data_ratio=opt.data_ratio,
|
449 |
+
normalize_v=not opt.no_norm_vfeat,
|
450 |
+
normalize_t=not opt.no_norm_tfeat,
|
451 |
+
clip_len=opt.clip_length,
|
452 |
+
max_windows=opt.max_windows,
|
453 |
+
load_labels=loadlabel, # opt.eval_split_name == "val",
|
454 |
+
span_loss_type=opt.span_loss_type,
|
455 |
+
txt_drop_ratio=0,
|
456 |
+
dset_domain=opt.dset_domain,
|
457 |
+
)
|
458 |
+
|
459 |
+
|
460 |
+
|
461 |
+
model, criterion, _, _ = setup_model(opt)
|
462 |
+
|
463 |
+
save_submission_filename = "hl_{}_submission.jsonl".format(
|
464 |
+
opt.eval_split_name)
|
465 |
+
# save_submission_filename = "inference_{}_{}_{}_preds.jsonl".format(
|
466 |
+
# opt.dset_name, opt.eval_split_name, opt.eval_id)
|
467 |
+
logger.info("Starting inference...")
|
468 |
+
with torch.no_grad():
|
469 |
+
metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \
|
470 |
+
eval_epoch(model, eval_dataset, opt, save_submission_filename, criterion=criterion)
|
471 |
+
if opt.eval_split_name == 'val':
|
472 |
+
logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4)))
|
473 |
+
if metrics_nms is not None:
|
474 |
+
logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4)))
|
475 |
+
|
476 |
+
from sys import argv
|
477 |
+
if __name__ == '__main__':
|
478 |
+
_,_,_,_,split,_,splitfile = argv
|
479 |
+
|
480 |
+
start_inference(split=split, splitfile=splitfile)
|
third_party/cgdetr/cg_detr/matcher.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
2 |
+
"""
|
3 |
+
Modules to compute the matching cost and solve the corresponding LSAP.
|
4 |
+
"""
|
5 |
+
import torch
|
6 |
+
from scipy.optimize import linear_sum_assignment
|
7 |
+
from torch import nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from third_party.cgdetr.cg_detr.span_utils import generalized_temporal_iou, span_cxw_to_xx
|
10 |
+
|
11 |
+
|
12 |
+
class HungarianMatcher(nn.Module):
|
13 |
+
"""This class computes an assignment between the targets and the predictions of the network
|
14 |
+
|
15 |
+
For efficiency reasons, the targets don't include the no_object. Because of this, in general,
|
16 |
+
there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
|
17 |
+
while the others are un-matched (and thus treated as non-objects).
|
18 |
+
"""
|
19 |
+
def __init__(self, cost_class: float = 1, cost_span: float = 1, cost_giou: float = 1,
|
20 |
+
span_loss_type: str = "l1", max_v_l: int = 75):
|
21 |
+
"""Creates the matcher
|
22 |
+
|
23 |
+
Params:
|
24 |
+
cost_span: This is the relative weight of the L1 error of the span coordinates in the matching cost
|
25 |
+
cost_giou: This is the relative weight of the giou loss of the spans in the matching cost
|
26 |
+
"""
|
27 |
+
super().__init__()
|
28 |
+
self.cost_class = cost_class
|
29 |
+
self.cost_span = cost_span
|
30 |
+
self.cost_giou = cost_giou
|
31 |
+
self.span_loss_type = span_loss_type
|
32 |
+
self.max_v_l = max_v_l
|
33 |
+
self.foreground_label = 0
|
34 |
+
assert cost_class != 0 or cost_span != 0 or cost_giou != 0, "all costs cant be 0"
|
35 |
+
|
36 |
+
@torch.no_grad()
|
37 |
+
def forward(self, outputs, targets):
|
38 |
+
""" Performs the matching
|
39 |
+
|
40 |
+
Params:
|
41 |
+
outputs: This is a dict that contains at least these entries:
|
42 |
+
"pred_spans": Tensor of dim [batch_size, num_queries, 2] with the predicted span coordinates,
|
43 |
+
in normalized (cx, w) format
|
44 |
+
""pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
|
45 |
+
|
46 |
+
targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
|
47 |
+
"spans": Tensor of dim [num_target_spans, 2] containing the target span coordinates. The spans are
|
48 |
+
in normalized (cx, w) format
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
A list of size batch_size, containing tuples of (index_i, index_j) where:
|
52 |
+
- index_i is the indices of the selected predictions (in order)
|
53 |
+
- index_j is the indices of the corresponding selected targets (in order)
|
54 |
+
For each batch element, it holds:
|
55 |
+
len(index_i) = len(index_j) = min(num_queries, num_target_spans)
|
56 |
+
"""
|
57 |
+
bs, num_queries = outputs["pred_spans"].shape[:2]
|
58 |
+
targets = targets["span_labels"]
|
59 |
+
# import pdb; pdb.set_trace()
|
60 |
+
|
61 |
+
# Also concat the target labels and spans
|
62 |
+
out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes]
|
63 |
+
tgt_spans = torch.cat([v["spans"] for v in targets]) # [num_target_spans in batch, 2]
|
64 |
+
tgt_ids = torch.full([len(tgt_spans)], self.foreground_label) # [total #spans in the batch]
|
65 |
+
|
66 |
+
# Compute the classification cost. Contrary to the loss, we don't use the NLL,
|
67 |
+
# but approximate it in 1 - prob[target class].
|
68 |
+
# The 1 is a constant that doesn't change the matching, it can be omitted.
|
69 |
+
cost_class = -out_prob[:, tgt_ids] # [batch_size * num_queries, total #spans in the batch]
|
70 |
+
|
71 |
+
if self.span_loss_type == "l1":
|
72 |
+
# We flatten to compute the cost matrices in a batch
|
73 |
+
out_spans = outputs["pred_spans"].flatten(0, 1) # [batch_size * num_queries, 2]
|
74 |
+
|
75 |
+
# Compute the L1 cost between spans
|
76 |
+
cost_span = torch.cdist(out_spans.type(torch.float32), tgt_spans.type(torch.float32), p=1) # [batch_size * num_queries, total #spans in the batch]
|
77 |
+
cost_span = cost_span.type(torch.bfloat16)
|
78 |
+
|
79 |
+
# Compute the giou cost between spans
|
80 |
+
# [batch_size * num_queries, total #spans in the batch]
|
81 |
+
cost_giou = - generalized_temporal_iou(span_cxw_to_xx(out_spans), span_cxw_to_xx(tgt_spans))
|
82 |
+
else:
|
83 |
+
pred_spans = outputs["pred_spans"] # (bsz, #queries, max_v_l * 2)
|
84 |
+
pred_spans = pred_spans.view(bs * num_queries, 2, self.max_v_l).softmax(-1) # (bsz * #queries, 2, max_v_l)
|
85 |
+
cost_span = - pred_spans[:, 0][:, tgt_spans[:, 0]] - \
|
86 |
+
pred_spans[:, 1][:, tgt_spans[:, 1]] # (bsz * #queries, #spans)
|
87 |
+
# pred_spans = pred_spans.repeat(1, n_spans, 1, 1).flatten(0, 1) # (bsz * #queries * #spans, max_v_l, 2)
|
88 |
+
# tgt_spans = tgt_spans.view(1, n_spans, 2).repeat(bs * num_queries, 1, 1).flatten(0, 1) # (bsz * #queries * #spans, 2)
|
89 |
+
# cost_span = pred_spans[tgt_spans]
|
90 |
+
# cost_span = cost_span.view(bs * num_queries, n_spans)
|
91 |
+
|
92 |
+
# giou
|
93 |
+
cost_giou = 0
|
94 |
+
|
95 |
+
# Final cost matrix
|
96 |
+
# import ipdb; ipdb.set_trace()
|
97 |
+
C = self.cost_span * cost_span + self.cost_giou * cost_giou + self.cost_class * cost_class
|
98 |
+
C = C.view(bs, num_queries, -1).cpu()
|
99 |
+
|
100 |
+
sizes = [len(v["spans"]) for v in targets]
|
101 |
+
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
|
102 |
+
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
|
103 |
+
|
104 |
+
|
105 |
+
def build_matcher(args):
|
106 |
+
return HungarianMatcher(
|
107 |
+
cost_span=args.set_cost_span, cost_giou=args.set_cost_giou,
|
108 |
+
cost_class=args.set_cost_class, span_loss_type=args.span_loss_type, max_v_l=args.max_v_l
|
109 |
+
)
|
third_party/cgdetr/cg_detr/misc.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
@torch.no_grad()
|
5 |
+
def accuracy(output, target, topk=(1,)):
|
6 |
+
"""Computes the precision@k for the specified values of k
|
7 |
+
output: (#items, #classes)
|
8 |
+
target: int,
|
9 |
+
"""
|
10 |
+
maxk = max(topk)
|
11 |
+
num_items = output.size(0)
|
12 |
+
|
13 |
+
_, pred = output.topk(maxk, 1, True, True)
|
14 |
+
pred = pred.t()
|
15 |
+
correct = pred.eq(target)
|
16 |
+
|
17 |
+
res = []
|
18 |
+
for k in topk:
|
19 |
+
correct_k = correct[:k].view(-1).float().sum(0)
|
20 |
+
res.append(correct_k.mul_(100.0 / num_items))
|
21 |
+
return res
|
third_party/cgdetr/cg_detr/model.py
ADDED
@@ -0,0 +1,1178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
2 |
+
"""
|
3 |
+
CG-DETR model and criterion classes.
|
4 |
+
"""
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch import nn
|
8 |
+
|
9 |
+
from third_party.cgdetr.cg_detr.span_utils import generalized_temporal_iou, span_cxw_to_xx
|
10 |
+
|
11 |
+
from third_party.cgdetr.cg_detr.matcher import build_matcher
|
12 |
+
from third_party.cgdetr.cg_detr.transformer import build_transformer, TransformerEncoderLayer, TransformerEncoder
|
13 |
+
from third_party.cgdetr.cg_detr.position_encoding import build_position_encoding
|
14 |
+
from third_party.cgdetr.cg_detr.misc import accuracy
|
15 |
+
import numpy as np
|
16 |
+
import copy
|
17 |
+
|
18 |
+
def inverse_sigmoid(x, eps=1e-3):
|
19 |
+
x = x.clamp(min=0, max=1)
|
20 |
+
x1 = x.clamp(min=eps)
|
21 |
+
x2 = (1 - x).clamp(min=eps)
|
22 |
+
return torch.log(x1/x2)
|
23 |
+
|
24 |
+
def init_weights(module):
|
25 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
26 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
27 |
+
elif isinstance(module, nn.LayerNorm):
|
28 |
+
module.bias.data.zero_()
|
29 |
+
module.weight.data.fill_(1.0)
|
30 |
+
|
31 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
32 |
+
module.bias.data.zero_()
|
33 |
+
|
34 |
+
def find_nth(vid, underline, n):
|
35 |
+
max_len = len(vid)
|
36 |
+
start = vid.find(underline)
|
37 |
+
while start >= 0 and n > 1:
|
38 |
+
start = vid.find(underline, start+len(underline))
|
39 |
+
n -= 1
|
40 |
+
if start == -1:
|
41 |
+
start = max_len
|
42 |
+
return start
|
43 |
+
|
44 |
+
def element_wise_list_equal(listA, listB):
|
45 |
+
res = []
|
46 |
+
for a, b in zip(listA, listB):
|
47 |
+
if a==b:
|
48 |
+
res.append(True)
|
49 |
+
else:
|
50 |
+
res.append(False)
|
51 |
+
return res
|
52 |
+
|
53 |
+
class CGDETR(nn.Module):
|
54 |
+
""" CG DETR. """
|
55 |
+
|
56 |
+
def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
|
57 |
+
num_queries, input_dropout, aux_loss=False,
|
58 |
+
contrastive_align_loss=False, contrastive_hdim=64,
|
59 |
+
max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2, aud_dim=0, args=None):
|
60 |
+
""" Initializes the model.
|
61 |
+
Parameters:
|
62 |
+
transformer: torch module of the transformer architecture. See transformer.py
|
63 |
+
position_embed: torch module of the position_embedding, See position_encoding.py
|
64 |
+
txt_position_embed: position_embedding for text
|
65 |
+
txt_dim: int, text query input dimension
|
66 |
+
vid_dim: int, video feature input dimension
|
67 |
+
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
|
68 |
+
CG-DETR can detect in a single video.
|
69 |
+
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
|
70 |
+
contrastive_align_loss: If true, perform span - tokens contrastive learning
|
71 |
+
contrastive_hdim: dimension used for projecting the embeddings before computing contrastive loss
|
72 |
+
max_v_l: int, maximum #clips in videos
|
73 |
+
span_loss_type: str, one of [l1, ce]
|
74 |
+
l1: (center-x, width) regression.
|
75 |
+
ce: (st_idx, ed_idx) classification.
|
76 |
+
# foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
|
77 |
+
# background_thd: float, intersection over prediction <= background_thd: labeled background
|
78 |
+
"""
|
79 |
+
super().__init__()
|
80 |
+
self.args=args
|
81 |
+
self.num_queries = num_queries
|
82 |
+
self.transformer = transformer
|
83 |
+
self.position_embed = position_embed
|
84 |
+
self.txt_position_embed = txt_position_embed
|
85 |
+
hidden_dim = transformer.d_model
|
86 |
+
self.span_loss_type = span_loss_type
|
87 |
+
self.max_v_l = max_v_l
|
88 |
+
span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
|
89 |
+
self.span_embed = MLP(hidden_dim, hidden_dim, span_pred_dim, 3)
|
90 |
+
self.class_embed = nn.Linear(hidden_dim, 2) # 0: background, 1: foreground
|
91 |
+
self.token_type_embeddings = nn.Embedding(2, hidden_dim)
|
92 |
+
self.token_type_embeddings.apply(init_weights)
|
93 |
+
self.use_txt_pos = use_txt_pos
|
94 |
+
self.n_input_proj = n_input_proj
|
95 |
+
self.query_embed = nn.Embedding(num_queries, 2)
|
96 |
+
relu_args = [True] * 3
|
97 |
+
relu_args[n_input_proj-1] = False
|
98 |
+
self.input_txt_proj = nn.Sequential(*[
|
99 |
+
LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
|
100 |
+
LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
|
101 |
+
LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
|
102 |
+
][:n_input_proj])
|
103 |
+
self.input_vid_proj = nn.Sequential(*[
|
104 |
+
LinearLayer(vid_dim + aud_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
|
105 |
+
LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
|
106 |
+
LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
|
107 |
+
][:n_input_proj])
|
108 |
+
self.contrastive_align_loss = contrastive_align_loss
|
109 |
+
if contrastive_align_loss:
|
110 |
+
self.contrastive_align_projection_query = nn.Linear(hidden_dim, contrastive_hdim)
|
111 |
+
self.contrastive_align_projection_txt = nn.Linear(hidden_dim, contrastive_hdim)
|
112 |
+
self.contrastive_align_projection_vid = nn.Linear(hidden_dim, contrastive_hdim)
|
113 |
+
|
114 |
+
self.saliency_proj1 = nn.Linear(hidden_dim, hidden_dim)
|
115 |
+
self.saliency_proj2 = nn.Linear(hidden_dim, hidden_dim)
|
116 |
+
self.aux_loss = aux_loss
|
117 |
+
self.hidden_dim = hidden_dim
|
118 |
+
self.global_rep_token = torch.nn.Parameter(torch.randn(args.total_prompts, hidden_dim))
|
119 |
+
self.global_rep_pos = torch.nn.Parameter(torch.randn(1, hidden_dim))
|
120 |
+
self.moment_rep_token = torch.nn.Parameter(torch.randn(hidden_dim))
|
121 |
+
self.moment_rep_pos = torch.nn.Parameter(torch.randn(hidden_dim))
|
122 |
+
|
123 |
+
self.dummy_rep_token = torch.nn.Parameter(torch.randn(args.num_dummies, hidden_dim))
|
124 |
+
self.dummy_rep_pos = torch.nn.Parameter(torch.randn(args.num_dummies, hidden_dim))
|
125 |
+
normalize_before = False
|
126 |
+
self.sent_rep_token = torch.nn.Parameter(torch.randn(hidden_dim))
|
127 |
+
self.sent_rep_pos = torch.nn.Parameter(torch.randn(hidden_dim))
|
128 |
+
|
129 |
+
self.txt_proj_linear = LinearLayer(txt_dim, hidden_dim, layer_norm=True)
|
130 |
+
|
131 |
+
input_txt_sa_proj = TransformerEncoderLayer(hidden_dim, 8, self.args.dim_feedforward, 0.1, "prelu", normalize_before)
|
132 |
+
txtproj_encoder_norm = nn.LayerNorm(hidden_dim) if normalize_before else None
|
133 |
+
self.txtproj_encoder = TransformerEncoder(input_txt_sa_proj, args.dummy_layers, txtproj_encoder_norm)
|
134 |
+
|
135 |
+
scls_encoder_layer = TransformerEncoderLayer(hidden_dim, 8, self.args.dim_feedforward, 0.1, "prelu", normalize_before)
|
136 |
+
scls_encoder_norm = nn.LayerNorm(hidden_dim) if normalize_before else None
|
137 |
+
self.scls_encoder = TransformerEncoder(scls_encoder_layer, args.sent_layers, scls_encoder_norm)
|
138 |
+
|
139 |
+
def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, vid=None, qid=None, src_aud=None, src_aud_mask=None, targets=None, prompt_token=None):
|
140 |
+
"""The forward expects two tensors:
|
141 |
+
- src_txt: [batch_size, L_txt, D_txt]
|
142 |
+
- src_txt_mask: [batch_size, L_txt], containing 0 on padded pixels,
|
143 |
+
will convert to 1 as padding later for transformer
|
144 |
+
- src_vid: [batch_size, L_vid, D_vid]
|
145 |
+
- src_vid_mask: [batch_size, L_vid], containing 0 on padded pixels,
|
146 |
+
will convert to 1 as padding later for transformer
|
147 |
+
|
148 |
+
It returns a dict with the following elements:
|
149 |
+
- "pred_spans": The normalized boxes coordinates for all queries, represented as
|
150 |
+
(center_x, width). These values are normalized in [0, 1],
|
151 |
+
relative to the size of each individual image (disregarding possible padding).
|
152 |
+
See PostProcess for information on how to retrieve the unnormalized bounding box.
|
153 |
+
- "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
|
154 |
+
dictionnaries containing the two above keys for each decoder layer.
|
155 |
+
"""
|
156 |
+
|
157 |
+
|
158 |
+
## For discovering real negative samples
|
159 |
+
device = src_txt_mask.device
|
160 |
+
# import pdb; pdb.set_trace()
|
161 |
+
# if vid is not None: ## for demo (run_on_video/run.py)
|
162 |
+
# _count = [v.count('_') for v in vid]
|
163 |
+
# if self.args.dset_name == 'hl':
|
164 |
+
# _position_to_cut = [find_nth(v, '_', _count[i]-1) for i, v in enumerate(vid)]
|
165 |
+
# ori_vid = [v[:_position_to_cut[i]] for i, v in enumerate(vid)]
|
166 |
+
# else:
|
167 |
+
if vid is not None:
|
168 |
+
ori_vid = [v for v in vid]
|
169 |
+
|
170 |
+
if src_aud is not None:
|
171 |
+
src_vid = torch.cat([src_vid, src_aud], dim=2)
|
172 |
+
|
173 |
+
# --------------------------------
|
174 |
+
src_txt_list = []
|
175 |
+
src_txt_mask_list = []
|
176 |
+
for bs in range(src_txt.shape[0]):
|
177 |
+
idx = int(src_txt_mask[bs].sum().item())
|
178 |
+
src_txt_list.append(torch.cat((src_txt[bs, :idx, :], prompt_token[bs], src_txt[bs, idx:, :]), dim=0))
|
179 |
+
src_txt_mask_list.append(torch.cat((src_txt_mask[bs, :idx], torch.ones(1, dtype=torch.bfloat16).to(device), src_txt_mask[bs, idx:]), dim=0))
|
180 |
+
|
181 |
+
src_txt = torch.stack(src_txt_list, dim=0)
|
182 |
+
src_txt_mask = torch.stack(src_txt_mask_list, dim=0)
|
183 |
+
# --------------------------------
|
184 |
+
|
185 |
+
# src_txt = torch.cat((src_txt, prompt_token), dim=1)
|
186 |
+
# src_txt_mask = torch.cat((src_txt_mask, torch.zeros_like(prompt_token)), dim=1)
|
187 |
+
|
188 |
+
src_vid = self.input_vid_proj(src_vid) # [bsz,vlen,770] -> [bsz,vlen,256]
|
189 |
+
src_txt = self.input_txt_proj(src_txt) # [bsz,qlen,4096] -> [bsz,qlen, 256]
|
190 |
+
|
191 |
+
src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1)) # TODO
|
192 |
+
src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
|
193 |
+
|
194 |
+
#
|
195 |
+
pos_vid = self.position_embed(src_vid, src_vid_mask).type(torch.bfloat16) # (bsz, L_vid, d)
|
196 |
+
pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt).type(torch.bfloat16) # (bsz, L_txt, d)
|
197 |
+
|
198 |
+
### insert dummy token in front of txt
|
199 |
+
txt_dummy = self.dummy_rep_token.reshape([1, self.args.num_dummies, self.hidden_dim]).repeat(src_txt.shape[0], 1, 1) # [bsz, 45, 256]
|
200 |
+
src_txt_dummy = torch.cat([txt_dummy, src_txt], dim=1) # [bsz, L_txt+45, 256]
|
201 |
+
mask_txt = torch.tensor([[True] * self.args.num_dummies]).to(src_txt_mask.device).repeat(src_txt_mask.shape[0], 1)
|
202 |
+
src_txt_mask_dummy = torch.cat([mask_txt, src_txt_mask], dim=1) # [bsz, L_txt+45]
|
203 |
+
|
204 |
+
pos_dummy = self.dummy_rep_pos.reshape([1, self.args.num_dummies, self.hidden_dim]).repeat(pos_txt.shape[0], 1, 1).type(torch.bfloat16)
|
205 |
+
pos_txt_dummy = torch.cat([pos_dummy, pos_txt], dim=1)
|
206 |
+
src_txt_dummy = src_txt_dummy.permute(1, 0, 2) # (L, batch_size, d)
|
207 |
+
pos_txt_dummy = pos_txt_dummy.permute(1, 0, 2) # (L, batch_size, d)
|
208 |
+
|
209 |
+
memory = self.txtproj_encoder(src_txt_dummy, src_key_padding_mask=~(src_txt_mask_dummy.bool()), pos=pos_txt_dummy) # (L, batch_size, d)
|
210 |
+
dummy_token = memory[:self.args.num_dummies].permute(1, 0, 2)
|
211 |
+
pos_txt_dummy = pos_txt_dummy.permute(1, 0, 2) # (L, batch_size, d)
|
212 |
+
|
213 |
+
src_txt_dummy = torch.cat([dummy_token, src_txt], dim=1)
|
214 |
+
mask_txt_dummy = torch.tensor([[True]*self.args.num_dummies]).to(src_txt_mask.device).repeat(src_txt_mask.shape[0], 1)
|
215 |
+
src_txt_mask_dummy = torch.cat([mask_txt_dummy, src_txt_mask], dim=1)
|
216 |
+
|
217 |
+
# Input : Concat video, dummy, txt
|
218 |
+
src = torch.cat([src_vid, src_txt_dummy], dim=1) # (bsz, L_vid+L_txt, d)
|
219 |
+
mask = torch.cat([src_vid_mask, src_txt_mask_dummy], dim=1).bool() # (bsz, L_vid+L_txt)
|
220 |
+
pos = torch.cat([pos_vid, pos_txt_dummy], dim=1)
|
221 |
+
|
222 |
+
### sentence token
|
223 |
+
smask_ = torch.tensor([[True]]).to(mask.device).repeat(src_txt_mask.shape[0], 1)
|
224 |
+
smask = torch.cat([smask_, src_txt_mask.bool()], dim=1)
|
225 |
+
ssrc_ = self.sent_rep_token.reshape([1, 1, self.hidden_dim]).repeat(src_txt.shape[0], 1, 1)
|
226 |
+
ssrc = torch.cat([ssrc_, src_txt], dim=1)
|
227 |
+
spos_ = self.sent_rep_pos.reshape([1, 1, self.hidden_dim]).repeat(pos_txt.shape[0], 1, 1)
|
228 |
+
spos = torch.cat([spos_, pos_txt], dim=1)
|
229 |
+
### dummy sentence token
|
230 |
+
smaskd = torch.cat([smask_, mask_txt_dummy.bool()], dim=1)
|
231 |
+
ssrcd = torch.cat([ssrc_, dummy_token], dim=1)
|
232 |
+
sposd = torch.cat([spos_, pos_dummy], dim=1)
|
233 |
+
|
234 |
+
if targets is not None: # train
|
235 |
+
mmask_ = torch.tensor([[True]]).to(mask.device).repeat(src_vid_mask.shape[0], 1)
|
236 |
+
mmask = torch.cat([mmask_, src_vid_mask.bool()], dim=1) # [bsz, L_vid+1]
|
237 |
+
moment_mask_ = torch.clamp(targets["relevant_clips"], 0, 1).bool()
|
238 |
+
moment_mask = torch.cat([mmask_, moment_mask_], dim=1) # [bsz, L_vid+1]
|
239 |
+
# if moment_mask.shape[1] != 76:
|
240 |
+
# import pdb; pdb.set_trace()
|
241 |
+
mmask = mmask * moment_mask
|
242 |
+
|
243 |
+
msrc_ = self.moment_rep_token.reshape([1, 1, self.hidden_dim]).repeat(src_vid.shape[0], 1, 1)
|
244 |
+
msrc = torch.cat([msrc_, src_vid], dim=1)
|
245 |
+
mpos_ = self.moment_rep_pos.reshape([1, 1, self.hidden_dim]).repeat(pos_vid.shape[0], 1, 1)
|
246 |
+
mpos = torch.cat([mpos_, pos_vid], dim=1)
|
247 |
+
|
248 |
+
### for Not moment token ####
|
249 |
+
nmmask_ = torch.tensor([[True]]).to(mask.device).repeat(src_vid_mask.shape[0], 1)
|
250 |
+
nmmask = torch.cat([nmmask_, src_vid_mask.bool()], dim=1)
|
251 |
+
nmoment_mask_ = ~(torch.clamp(targets["relevant_clips"], 0, 1).bool())
|
252 |
+
nmoment_mask = torch.cat([nmmask_, nmoment_mask_], dim=1)
|
253 |
+
nmmask = nmmask * nmoment_mask
|
254 |
+
|
255 |
+
nmsrc_ = self.moment_rep_token.reshape([1, 1, self.hidden_dim]).repeat(src_vid.shape[0], 1, 1)
|
256 |
+
nmsrc = torch.cat([nmsrc_, src_vid], dim=1)
|
257 |
+
nmpos_ = self.moment_rep_pos.reshape([1, 1, self.hidden_dim]).repeat(pos_vid.shape[0], 1, 1)
|
258 |
+
nmpos = torch.cat([nmpos_, pos_vid], dim=1)
|
259 |
+
###########
|
260 |
+
else:
|
261 |
+
moment_mask_ = None
|
262 |
+
|
263 |
+
# for t2vidavg sal token
|
264 |
+
# import pdb; pdb.set_trace()
|
265 |
+
vidsrc_ = torch.zeros((len(src_vid), 1, self.hidden_dim), dtype=torch.bfloat16).to(device)
|
266 |
+
for i in range(len(src_vid)):
|
267 |
+
vidsrc_[i] = src_vid[i][:src_vid_mask.sum(1)[i].long()].mean(0).clone().detach()
|
268 |
+
|
269 |
+
video_length = src_vid.shape[1]
|
270 |
+
if targets is not None: ## train
|
271 |
+
ssrc = ssrc.permute(1, 0, 2) # (L, batch_size, d)
|
272 |
+
spos = spos.permute(1, 0, 2) # (L, batch_size, d)
|
273 |
+
smemory = self.scls_encoder(ssrc, src_key_padding_mask=~smask, pos=spos) # (L, batch_size, d)
|
274 |
+
sentence_txt, smemory_words = smemory[0], smemory[1:] # sentence_txt : (batch_size, d)
|
275 |
+
|
276 |
+
ssrcd = ssrcd.permute(1, 0, 2) # (L, batch_size, d)
|
277 |
+
sposd = sposd.permute(1, 0, 2) # (L, batch_size, d)
|
278 |
+
smemoryd = self.scls_encoder(ssrcd, src_key_padding_mask=~smaskd, pos=sposd) # (L, batch_size, d)
|
279 |
+
sentence_dummy, smemory_words_dummy = smemoryd[0], smemoryd[1:]
|
280 |
+
|
281 |
+
txt_dummy_proj = torch.cat([smemory_words_dummy, smemory_words], dim=0)
|
282 |
+
|
283 |
+
# import pdb; pdb.set_trace()
|
284 |
+
# print(src.dtype)
|
285 |
+
hs, reference, memory, memory_global, attn_weights, memory_moment, nmmemory_moment, mmemory_frames, nmmemory_frames = self.transformer(src, ~mask, self.query_embed.weight, pos, video_length=video_length, moment_idx=targets["relevant_clips"], msrc=msrc, mpos=mpos, mmask=~mmask, nmsrc=nmsrc, nmpos=nmpos, nmmask=~nmmask,
|
286 |
+
ctxtoken=vidsrc_, gtoken=self.global_rep_token, gpos=self.global_rep_pos, vlen=src_vid_mask.sum(1).long())
|
287 |
+
moment2txt_similarity = torch.matmul(mmemory_frames.permute(1, 0, 2), txt_dummy_proj.permute(1, 2, 0))
|
288 |
+
nmoment2txt_similarity = torch.matmul(nmmemory_frames.permute(1, 0, 2), txt_dummy_proj.permute(1, 2, 0))
|
289 |
+
else: ## inference
|
290 |
+
sentence_dummy, sentence_txt, moment2txt_similarity, nmoment2txt_similarity = None, None, None, None
|
291 |
+
hs, reference, memory, memory_global, attn_weights, memory_moment, nmmemory_moment, mmemory_frames, nmmemory_frames = self.transformer(src, ~mask, self.query_embed.weight, pos, video_length=video_length,
|
292 |
+
ctxtoken=vidsrc_, gtoken=self.global_rep_token, gpos=self.global_rep_pos, vlen=src_vid_mask.sum(1).long())
|
293 |
+
outputs_class = self.class_embed(hs) # (#layers, batch_size, #queries, #classes)
|
294 |
+
reference_before_sigmoid = inverse_sigmoid(reference)
|
295 |
+
tmp = self.span_embed(hs)
|
296 |
+
outputs_coord = tmp + reference_before_sigmoid
|
297 |
+
if self.span_loss_type == "l1":
|
298 |
+
outputs_coord = outputs_coord.sigmoid()
|
299 |
+
out = {'pred_logits': outputs_class[-1], 'pred_spans': outputs_coord[-1]}
|
300 |
+
|
301 |
+
txt_mem = memory[:, src_vid.shape[1]:] # (bsz, L_txt, d)
|
302 |
+
vid_mem = memory[:, :src_vid.shape[1]] # (bsz, L_vid, d)
|
303 |
+
if self.contrastive_align_loss:
|
304 |
+
proj_queries = F.normalize(self.contrastive_align_projection_query(hs), p=2, dim=-1)
|
305 |
+
proj_txt_mem = F.normalize(self.contrastive_align_projection_txt(txt_mem), p=2, dim=-1)
|
306 |
+
proj_vid_mem = F.normalize(self.contrastive_align_projection_vid(vid_mem), p=2, dim=-1)
|
307 |
+
out.update(dict(
|
308 |
+
proj_queries=proj_queries[-1],
|
309 |
+
proj_txt_mem=proj_txt_mem,
|
310 |
+
proj_vid_mem=proj_vid_mem
|
311 |
+
))
|
312 |
+
|
313 |
+
if vid is not None: ## for demo (run_on_video/run.py)
|
314 |
+
### Neg Pairs ###
|
315 |
+
neg_vid = ori_vid[1:] + ori_vid[:1]
|
316 |
+
|
317 |
+
real_neg_mask = torch.Tensor(element_wise_list_equal(ori_vid, neg_vid)).to(src_txt_dummy.device)
|
318 |
+
real_neg_mask = real_neg_mask.type(torch.bfloat16)
|
319 |
+
|
320 |
+
real_neg_mask = real_neg_mask == False
|
321 |
+
|
322 |
+
# import pdb; pdb.set_trace()
|
323 |
+
if real_neg_mask.sum() != 0:
|
324 |
+
|
325 |
+
src_txt_dummy_neg = torch.cat([src_txt_dummy[1:], src_txt_dummy[0:1]], dim=0)
|
326 |
+
src_txt_mask_dummy_neg = torch.cat([src_txt_mask_dummy[1:], src_txt_mask_dummy[0:1]], dim=0)
|
327 |
+
src_dummy_neg = torch.cat([src_vid, src_txt_dummy_neg], dim=1)
|
328 |
+
mask_dummy_neg = torch.cat([src_vid_mask, src_txt_mask_dummy_neg], dim=1).bool()
|
329 |
+
pos_neg = pos.clone() # since it does not use actual content
|
330 |
+
|
331 |
+
mask_dummy_neg = mask_dummy_neg[real_neg_mask]
|
332 |
+
src_dummy_neg = src_dummy_neg[real_neg_mask]
|
333 |
+
pos_neg = pos_neg[real_neg_mask]
|
334 |
+
src_txt_mask_dummy_neg = src_txt_mask_dummy_neg[real_neg_mask]
|
335 |
+
|
336 |
+
# import pdb; pdb.set_trace()
|
337 |
+
_, _, memory_neg, memory_global_neg, attn_weights_neg, _, _, _, _ = self.transformer(src_dummy_neg, ~mask_dummy_neg, self.query_embed.weight, pos_neg, video_length=video_length,
|
338 |
+
ctxtoken=vidsrc_[real_neg_mask], gtoken=self.global_rep_token, gpos=self.global_rep_pos, vlen=src_vid_mask[real_neg_mask].sum(1).long())
|
339 |
+
vid_mem_neg = memory_neg[:, :src_vid.shape[1]]
|
340 |
+
out["saliency_scores_neg"] = (torch.sum(self.saliency_proj1(vid_mem_neg) * self.saliency_proj2(memory_global_neg).unsqueeze(1), dim=-1) / np.sqrt(self.hidden_dim))
|
341 |
+
out["src_txt_mask_neg"] = src_txt_mask_dummy_neg
|
342 |
+
|
343 |
+
out["t2vattnvalues_neg"] = (attn_weights_neg[:, :, self.args.num_dummies:] * (src_txt_mask_dummy_neg[:, self.args.num_dummies:].unsqueeze(1).repeat(1, video_length, 1))).sum(2)
|
344 |
+
out["t2vattnvalues_neg"] = torch.clamp(out["t2vattnvalues_neg"], 0, 1)
|
345 |
+
else:
|
346 |
+
out["saliency_scores_neg"] = None
|
347 |
+
out["t2vattnvalues_neg"] = None
|
348 |
+
out["real_neg_mask"] = real_neg_mask
|
349 |
+
else:
|
350 |
+
out["saliency_scores_neg"] = None
|
351 |
+
out["t2vattnvalues_neg"] = None
|
352 |
+
out["real_neg_mask"] = None
|
353 |
+
|
354 |
+
|
355 |
+
out["saliency_scores"] = (torch.sum(self.saliency_proj1(vid_mem) * self.saliency_proj2(memory_global).unsqueeze(1), dim=-1) / np.sqrt(self.hidden_dim))
|
356 |
+
out["memory_moment"] = memory_moment
|
357 |
+
out["nmmemory_moment"] = nmmemory_moment
|
358 |
+
|
359 |
+
## sentence token embeeded with text / dummy
|
360 |
+
out["sentence_txt"] = sentence_txt
|
361 |
+
out["sentence_dummy"] = sentence_dummy
|
362 |
+
out["moment2txt_similarity"] = moment2txt_similarity
|
363 |
+
out["nmoment2txt_similarity"] = nmoment2txt_similarity
|
364 |
+
out["cate_attn_weights"] = attn_weights
|
365 |
+
out["moment_mask"] = moment_mask_
|
366 |
+
out["txt_mask"] = src_txt_mask_dummy
|
367 |
+
|
368 |
+
|
369 |
+
out["t2vattnvalues"] = (attn_weights[:,:,self.args.num_dummies:] * (src_txt_mask.unsqueeze(1).repeat(1, video_length, 1))).sum(2) # (batch_size, L_vid, L_txt) / (batch_size, L_txt)
|
370 |
+
out["t2vattnvalues"] = torch.clamp(out["t2vattnvalues"], 0, 1)
|
371 |
+
out["dummy_tokens"] = dummy_token
|
372 |
+
out["global_rep_tokens"] = self.global_rep_token
|
373 |
+
|
374 |
+
# import pdb; pdb.set_trace()
|
375 |
+
if targets is not None:
|
376 |
+
out["src_vid"] = mmemory_frames.permute(1, 0, 2) * moment_mask_.unsqueeze(2) + nmmemory_frames.permute(1, 0, 2) * (~(moment_mask_.unsqueeze(2).bool())).bfloat16()
|
377 |
+
else:
|
378 |
+
out["src_vid"] = None
|
379 |
+
|
380 |
+
out["video_mask"] = src_vid_mask
|
381 |
+
if self.aux_loss:
|
382 |
+
# assert proj_queries and proj_txt_mem
|
383 |
+
out['aux_outputs'] = [
|
384 |
+
{'pred_logits': a, 'pred_spans': b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
|
385 |
+
if self.contrastive_align_loss:
|
386 |
+
assert proj_queries is not None
|
387 |
+
for idx, d in enumerate(proj_queries[:-1]):
|
388 |
+
out['aux_outputs'][idx].update(dict(proj_queries=d, proj_txt_mem=proj_txt_mem))
|
389 |
+
return out
|
390 |
+
|
391 |
+
class SetCriterion(nn.Module):
|
392 |
+
""" This class computes the loss for DETR.
|
393 |
+
The process happens in two steps:
|
394 |
+
1) we compute hungarian assignment between ground truth boxes and the outputs of the model
|
395 |
+
2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
|
396 |
+
"""
|
397 |
+
|
398 |
+
def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
|
399 |
+
saliency_margin=1, use_matcher=True, args=None):
|
400 |
+
""" Create the criterion.
|
401 |
+
Parameters:
|
402 |
+
matcher: module able to compute a matching between targets and proposals
|
403 |
+
weight_dict: dict containing as key the names of the losses and as values their relative weight.
|
404 |
+
eos_coef: relative classification weight applied to the no-object category
|
405 |
+
losses: list of all the losses to be applied. See get_loss for list of available losses.
|
406 |
+
temperature: float, temperature for NCE loss
|
407 |
+
span_loss_type: str, [l1, ce]
|
408 |
+
max_v_l: int,
|
409 |
+
saliency_margin: float
|
410 |
+
"""
|
411 |
+
super().__init__()
|
412 |
+
self.args=args
|
413 |
+
self.matcher = matcher
|
414 |
+
self.weight_dict = weight_dict
|
415 |
+
self.losses = losses
|
416 |
+
self.temperature = temperature
|
417 |
+
self.span_loss_type = span_loss_type
|
418 |
+
self.max_v_l = max_v_l
|
419 |
+
self.saliency_margin = saliency_margin
|
420 |
+
|
421 |
+
# foreground and background classification
|
422 |
+
self.foreground_label = 0
|
423 |
+
self.background_label = 1
|
424 |
+
self.eos_coef = eos_coef
|
425 |
+
empty_weight = torch.ones(2)
|
426 |
+
empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
|
427 |
+
self.register_buffer('empty_weight', empty_weight)
|
428 |
+
|
429 |
+
# for tvsum,
|
430 |
+
self.use_matcher = use_matcher
|
431 |
+
|
432 |
+
# moment sentence contrastive
|
433 |
+
self.criterion = torch.nn.CrossEntropyLoss()#.to(self.args.device)
|
434 |
+
self.l2_criterion = torch.nn.MSELoss()#.to(self.args.device)
|
435 |
+
self.kld_criterion = torch.nn.KLDivLoss(reduction='none')#.to(self.args.device)
|
436 |
+
self.bce_criterion = nn.BCELoss(reduction='none')
|
437 |
+
|
438 |
+
def loss_spans(self, outputs, targets, indices):
|
439 |
+
"""Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
|
440 |
+
targets dicts must contain the key "spans" containing a tensor of dim [nb_tgt_spans, 2]
|
441 |
+
The target spans are expected in format (center_x, w), normalized by the image size.
|
442 |
+
"""
|
443 |
+
assert 'pred_spans' in outputs
|
444 |
+
targets = targets["span_labels"]
|
445 |
+
idx = self._get_src_permutation_idx(indices)
|
446 |
+
src_spans = outputs['pred_spans'][idx] # (#spans, max_v_l * 2)
|
447 |
+
tgt_spans = torch.cat([t['spans'][i] for t, (_, i) in zip(targets, indices)], dim=0) # (#spans, 2)
|
448 |
+
if self.span_loss_type == "l1":
|
449 |
+
loss_span = F.l1_loss(src_spans, tgt_spans, reduction='none')
|
450 |
+
loss_giou = 1 - torch.diag(generalized_temporal_iou(span_cxw_to_xx(src_spans), span_cxw_to_xx(tgt_spans)))
|
451 |
+
else: # ce
|
452 |
+
n_spans = src_spans.shape[0]
|
453 |
+
src_spans = src_spans.view(n_spans, 2, self.max_v_l).transpose(1, 2)
|
454 |
+
loss_span = F.cross_entropy(src_spans, tgt_spans, reduction='none')
|
455 |
+
loss_giou = loss_span.new_zeros([1])
|
456 |
+
|
457 |
+
losses = {}
|
458 |
+
losses['loss_span'] = loss_span.mean()
|
459 |
+
losses['loss_giou'] = loss_giou.mean()
|
460 |
+
return losses
|
461 |
+
|
462 |
+
def loss_labels(self, outputs, targets, indices, log=True):
|
463 |
+
"""Classification loss (NLL)
|
464 |
+
targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
|
465 |
+
"""
|
466 |
+
# TODO add foreground and background classifier. use all non-matched as background.
|
467 |
+
assert 'pred_logits' in outputs
|
468 |
+
src_logits = outputs['pred_logits'] # (batch_size, #queries, #classes=2)
|
469 |
+
# idx is a tuple of two 1D tensors (batch_idx, src_idx), of the same length == #objects in batch
|
470 |
+
idx = self._get_src_permutation_idx(indices)
|
471 |
+
target_classes = torch.full(src_logits.shape[:2], self.background_label,
|
472 |
+
dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
|
473 |
+
target_classes[idx] = self.foreground_label
|
474 |
+
|
475 |
+
loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight, reduction="none")
|
476 |
+
losses = {'loss_label': loss_ce.mean()}
|
477 |
+
|
478 |
+
if log:
|
479 |
+
# TODO this should probably be a separate loss, not hacked in this one here
|
480 |
+
losses['class_error'] = 100 - accuracy(src_logits[idx], self.foreground_label)[0]
|
481 |
+
return losses
|
482 |
+
|
483 |
+
def loss_saliency(self, outputs, targets, indices, log=True):
|
484 |
+
"""higher scores for positive clips"""
|
485 |
+
if "saliency_pos_labels" not in targets:
|
486 |
+
return {"loss_saliency": 0}
|
487 |
+
|
488 |
+
# Neg pair loss
|
489 |
+
if outputs["saliency_scores_neg"] is not None: ## When batch size is not 1 (negative pair exists)
|
490 |
+
vid_token_mask = outputs["video_mask"]
|
491 |
+
real_neg_mask = outputs["real_neg_mask"]
|
492 |
+
saliency_scores_neg = outputs["saliency_scores_neg"].clone() # (N, L)
|
493 |
+
loss_neg_pair = (- torch.log(1. - torch.sigmoid(saliency_scores_neg)) * (vid_token_mask[real_neg_mask])).sum(dim=1).mean()
|
494 |
+
|
495 |
+
saliency_scores = outputs["saliency_scores"].clone() # (N, L)
|
496 |
+
saliency_contrast_label = targets["saliency_all_labels"]
|
497 |
+
|
498 |
+
# real neg
|
499 |
+
realneg_saliency_scores = torch.cat([saliency_scores[real_neg_mask], saliency_scores_neg], dim=1)
|
500 |
+
realneg_saliency_contrast_label = torch.cat([saliency_contrast_label[real_neg_mask], torch.zeros_like(saliency_contrast_label)[real_neg_mask]], dim=1)
|
501 |
+
realneg_vid_token_mask = vid_token_mask[real_neg_mask].repeat([1, 2])
|
502 |
+
realneg_saliency_scores = realneg_vid_token_mask * realneg_saliency_scores + (1. - realneg_vid_token_mask) * -1e+3
|
503 |
+
|
504 |
+
tau = 0.5
|
505 |
+
loss_rank_contrastive = 0.
|
506 |
+
for rand_idx in range(1, 12):
|
507 |
+
drop_mask = ~(realneg_saliency_contrast_label > 100) # no drop
|
508 |
+
pos_mask = (realneg_saliency_contrast_label >= rand_idx) # positive when equal or higher than rand_idx
|
509 |
+
if torch.sum(pos_mask) == 0: # no positive sample
|
510 |
+
continue
|
511 |
+
else:
|
512 |
+
batch_drop_mask = torch.sum(pos_mask, dim=1) > 0 # negative sample indicator
|
513 |
+
|
514 |
+
# drop higher ranks
|
515 |
+
cur_saliency_scores = realneg_saliency_scores * drop_mask / tau + ~drop_mask * -1e+3
|
516 |
+
# numerical stability
|
517 |
+
logits = cur_saliency_scores - torch.max(cur_saliency_scores, dim=1, keepdim=True)[0]
|
518 |
+
# softmax
|
519 |
+
exp_logits = torch.exp(logits)
|
520 |
+
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6)
|
521 |
+
|
522 |
+
mean_log_prob_pos = (pos_mask * log_prob * realneg_vid_token_mask).sum(1) / (pos_mask.sum(1) + 1e-6)
|
523 |
+
loss = - mean_log_prob_pos * batch_drop_mask
|
524 |
+
loss_rank_contrastive = loss_rank_contrastive + loss.mean()
|
525 |
+
loss_rank_contrastive = loss_rank_contrastive / 12
|
526 |
+
|
527 |
+
false_neg_mask = ~(real_neg_mask)
|
528 |
+
if false_neg_mask.sum() != 0:
|
529 |
+
if false_neg_mask.sum() == 1:
|
530 |
+
falseneg_saliency_scores = saliency_scores[false_neg_mask].unsqueeze(0)
|
531 |
+
falseneg_saliency_contrast_label = saliency_contrast_label[false_neg_mask].unsqueeze(0)
|
532 |
+
falseneg_vid_token_mask = vid_token_mask[false_neg_mask].unsqueeze(0)
|
533 |
+
falseneg_saliency_scores = falseneg_vid_token_mask * falseneg_saliency_scores + (1. - falseneg_vid_token_mask) * -1e+3
|
534 |
+
else:
|
535 |
+
falseneg_saliency_scores = saliency_scores[false_neg_mask]
|
536 |
+
falseneg_saliency_contrast_label = saliency_contrast_label[false_neg_mask]
|
537 |
+
falseneg_vid_token_mask = vid_token_mask[false_neg_mask]
|
538 |
+
falseneg_saliency_scores = falseneg_vid_token_mask * falseneg_saliency_scores + (1. - falseneg_vid_token_mask) * -1e+3
|
539 |
+
|
540 |
+
tau = 0.5
|
541 |
+
falseneg_loss_rank_contrastive = 0.
|
542 |
+
for rand_idx in range(1, 12):
|
543 |
+
drop_mask = ~(falseneg_saliency_contrast_label > 100) # no drop
|
544 |
+
pos_mask = (falseneg_saliency_contrast_label >= rand_idx) # positive when equal or higher than rand_idx
|
545 |
+
if torch.sum(pos_mask) == 0: # no positive sample
|
546 |
+
continue
|
547 |
+
else:
|
548 |
+
batch_drop_mask = torch.sum(pos_mask, dim=1) > 0 # negative sample indicator
|
549 |
+
|
550 |
+
# drop higher ranks
|
551 |
+
cur_saliency_scores = falseneg_saliency_scores * drop_mask / tau + ~drop_mask * -1e+3
|
552 |
+
# numerical stability
|
553 |
+
logits = cur_saliency_scores - torch.max(cur_saliency_scores, dim=1, keepdim=True)[0]
|
554 |
+
# softmax
|
555 |
+
exp_logits = torch.exp(logits)
|
556 |
+
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6)
|
557 |
+
|
558 |
+
mean_log_prob_pos = (pos_mask * log_prob * falseneg_vid_token_mask).sum(1) / (pos_mask.sum(1) + 1e-6)
|
559 |
+
loss = - mean_log_prob_pos * batch_drop_mask
|
560 |
+
falseneg_loss_rank_contrastive = falseneg_loss_rank_contrastive + loss.mean()
|
561 |
+
falseneg_loss_rank_contrastive = falseneg_loss_rank_contrastive / 12
|
562 |
+
loss_rank_contrastive += falseneg_loss_rank_contrastive
|
563 |
+
|
564 |
+
saliency_scores = outputs["saliency_scores"] # (N, L)
|
565 |
+
pos_indices = targets["saliency_pos_labels"] # (N, #pairs)
|
566 |
+
neg_indices = targets["saliency_neg_labels"] # (N, #pairs)
|
567 |
+
num_pairs = pos_indices.shape[1] # typically 2 or 4
|
568 |
+
batch_indices = torch.arange(len(saliency_scores)).to(saliency_scores.device)
|
569 |
+
pos_scores = torch.stack(
|
570 |
+
[saliency_scores[batch_indices, pos_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
|
571 |
+
neg_scores = torch.stack(
|
572 |
+
[saliency_scores[batch_indices, neg_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
|
573 |
+
loss_saliency = torch.clamp(self.saliency_margin + neg_scores - pos_scores, min=0).sum() \
|
574 |
+
/ (len(pos_scores) * num_pairs) * 2 # * 2 to keep the loss the same scale
|
575 |
+
|
576 |
+
# if self.args.dset_name in ['youtube_uni']:
|
577 |
+
# loss_saliency = loss_saliency + loss_rank_contrastive + loss_neg_pair * 0.
|
578 |
+
# else:
|
579 |
+
loss_saliency = loss_saliency + loss_rank_contrastive + loss_neg_pair
|
580 |
+
|
581 |
+
########### Saliency loss to t2v attn weights ##############
|
582 |
+
"""higher scores for positive clips"""
|
583 |
+
vid_token_mask = outputs["video_mask"]
|
584 |
+
# Neg pair loss
|
585 |
+
|
586 |
+
if outputs["t2vattnvalues_neg"] is not None:
|
587 |
+
saliency_scores_neg = outputs["t2vattnvalues_neg"].clone() # (N, L)
|
588 |
+
loss_neg_pair_attn = (- torch.log(1. - saliency_scores_neg) * (vid_token_mask[real_neg_mask])).sum(dim=1).mean()
|
589 |
+
|
590 |
+
saliency_scores = outputs["t2vattnvalues"].clone() # (N, L)
|
591 |
+
saliency_contrast_label = targets["saliency_all_labels"]
|
592 |
+
|
593 |
+
# real neg
|
594 |
+
realneg_saliency_scores = torch.cat([saliency_scores[real_neg_mask], saliency_scores_neg], dim=1)
|
595 |
+
realneg_saliency_contrast_label = torch.cat(
|
596 |
+
[saliency_contrast_label[real_neg_mask], torch.zeros_like(saliency_contrast_label)[real_neg_mask]], dim=1)
|
597 |
+
realneg_vid_token_mask = vid_token_mask[real_neg_mask].repeat([1, 2])
|
598 |
+
realneg_saliency_scores = realneg_vid_token_mask * realneg_saliency_scores + (
|
599 |
+
1. - realneg_vid_token_mask) * -1e+3
|
600 |
+
|
601 |
+
tau = 0.5
|
602 |
+
loss_rank_contrastive_attn = 0.
|
603 |
+
for rand_idx in range(1, 12):
|
604 |
+
drop_mask = ~(realneg_saliency_contrast_label > 100) # no drop
|
605 |
+
pos_mask = (realneg_saliency_contrast_label >= rand_idx) # positive when equal or higher than rand_idx
|
606 |
+
if torch.sum(pos_mask) == 0: # no positive sample
|
607 |
+
continue
|
608 |
+
else:
|
609 |
+
batch_drop_mask = torch.sum(pos_mask, dim=1) > 0 # negative sample indicator
|
610 |
+
|
611 |
+
# drop higher ranks
|
612 |
+
cur_saliency_scores = realneg_saliency_scores * drop_mask / tau + ~drop_mask * -1e+3
|
613 |
+
# numerical stability
|
614 |
+
logits = cur_saliency_scores - torch.max(cur_saliency_scores, dim=1, keepdim=True)[0]
|
615 |
+
# softmax
|
616 |
+
exp_logits = torch.exp(logits)
|
617 |
+
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6)
|
618 |
+
|
619 |
+
mean_log_prob_pos = (pos_mask * log_prob * realneg_vid_token_mask).sum(1) / (pos_mask.sum(1) + 1e-6)
|
620 |
+
loss = - mean_log_prob_pos * batch_drop_mask
|
621 |
+
loss_rank_contrastive_attn = loss_rank_contrastive_attn + loss.mean()
|
622 |
+
loss_rank_contrastive_attn = loss_rank_contrastive_attn / 12
|
623 |
+
|
624 |
+
false_neg_mask = ~(real_neg_mask)
|
625 |
+
if false_neg_mask.sum() != 0:
|
626 |
+
if false_neg_mask.sum() == 1:
|
627 |
+
falseneg_saliency_scores = saliency_scores[false_neg_mask].unsqueeze(0)
|
628 |
+
falseneg_saliency_contrast_label = saliency_contrast_label[false_neg_mask].unsqueeze(0)
|
629 |
+
falseneg_vid_token_mask = vid_token_mask[false_neg_mask].unsqueeze(0)
|
630 |
+
falseneg_saliency_scores = falseneg_vid_token_mask * falseneg_saliency_scores + (1. - falseneg_vid_token_mask) * -1e+3
|
631 |
+
else:
|
632 |
+
falseneg_saliency_scores = saliency_scores[false_neg_mask]
|
633 |
+
falseneg_saliency_contrast_label = saliency_contrast_label[false_neg_mask]
|
634 |
+
falseneg_vid_token_mask = vid_token_mask[false_neg_mask]
|
635 |
+
falseneg_saliency_scores = falseneg_vid_token_mask * falseneg_saliency_scores + (1. - falseneg_vid_token_mask) * -1e+3
|
636 |
+
|
637 |
+
tau = 0.5
|
638 |
+
falseneg_loss_rank_contrastive = 0.
|
639 |
+
for rand_idx in range(1, 12):
|
640 |
+
drop_mask = ~(falseneg_saliency_contrast_label > 100) # no drop
|
641 |
+
pos_mask = (falseneg_saliency_contrast_label >= rand_idx) # positive when equal or higher than rand_idx
|
642 |
+
if torch.sum(pos_mask) == 0: # no positive sample
|
643 |
+
continue
|
644 |
+
else:
|
645 |
+
batch_drop_mask = torch.sum(pos_mask, dim=1) > 0 # negative sample indicator
|
646 |
+
|
647 |
+
# drop higher ranks
|
648 |
+
cur_saliency_scores = falseneg_saliency_scores * drop_mask / tau + ~drop_mask * -1e+3
|
649 |
+
# numerical stability
|
650 |
+
logits = cur_saliency_scores - torch.max(cur_saliency_scores, dim=1, keepdim=True)[0]
|
651 |
+
# softmax
|
652 |
+
exp_logits = torch.exp(logits)
|
653 |
+
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6)
|
654 |
+
|
655 |
+
mean_log_prob_pos = (pos_mask * log_prob * falseneg_vid_token_mask).sum(1) / (pos_mask.sum(1) + 1e-6)
|
656 |
+
loss = - mean_log_prob_pos * batch_drop_mask
|
657 |
+
falseneg_loss_rank_contrastive = falseneg_loss_rank_contrastive + loss.mean()
|
658 |
+
falseneg_loss_rank_contrastive = falseneg_loss_rank_contrastive / 12
|
659 |
+
loss_rank_contrastive += falseneg_loss_rank_contrastive
|
660 |
+
|
661 |
+
saliency_scores = outputs["t2vattnvalues"] # (N, L)
|
662 |
+
pos_indices = targets["saliency_pos_labels"] # (N, #pairs)
|
663 |
+
neg_indices = targets["saliency_neg_labels"] # (N, #pairs)
|
664 |
+
num_pairs = pos_indices.shape[1] # typically 2 or 4
|
665 |
+
batch_indices = torch.arange(len(saliency_scores)).to(saliency_scores.device)
|
666 |
+
pos_scores = torch.stack(
|
667 |
+
[saliency_scores[batch_indices, pos_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
|
668 |
+
neg_scores = torch.stack(
|
669 |
+
[saliency_scores[batch_indices, neg_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
|
670 |
+
loss_saliency_attn = torch.clamp(self.saliency_margin + neg_scores - pos_scores, min=0).sum() \
|
671 |
+
/ (len(pos_scores) * num_pairs) * 2 # * 2 to keep the loss the same scale
|
672 |
+
|
673 |
+
saliency_binary_label = torch.clamp(targets["saliency_all_labels"], 0, 1)
|
674 |
+
logits = saliency_scores.reshape(-1)
|
675 |
+
labels_x = saliency_binary_label.reshape(-1)
|
676 |
+
BCEcriterion = nn.BCELoss()
|
677 |
+
bceloss = BCEcriterion(logits, labels_x)
|
678 |
+
|
679 |
+
# if self.args.dset_name in ['youtube_uni']:
|
680 |
+
# loss_saliency_attn = loss_rank_contrastive_attn + bceloss + loss_neg_pair_attn * 0 + loss_saliency_attn
|
681 |
+
# else:
|
682 |
+
loss_saliency_attn = loss_rank_contrastive_attn + bceloss + loss_neg_pair_attn + loss_saliency_attn
|
683 |
+
|
684 |
+
loss_saliency += (loss_saliency_attn * self.args.lw_wattn)
|
685 |
+
|
686 |
+
else: ## when batch size == 1
|
687 |
+
vid_token_mask = outputs["video_mask"]
|
688 |
+
saliency_scores = outputs["saliency_scores"].clone() # (N, L)
|
689 |
+
saliency_contrast_label = targets["saliency_all_labels"]
|
690 |
+
|
691 |
+
saliency_scores = vid_token_mask * saliency_scores + (1. - vid_token_mask) * -1e+3
|
692 |
+
|
693 |
+
tau = 0.5
|
694 |
+
loss_rank_contrastive = 0.
|
695 |
+
for rand_idx in range(1, 12):
|
696 |
+
drop_mask = ~(saliency_contrast_label > 100) # no drop
|
697 |
+
pos_mask = (saliency_contrast_label >= rand_idx) # positive when equal or higher than rand_idx
|
698 |
+
if torch.sum(pos_mask) == 0: # no positive sample
|
699 |
+
continue
|
700 |
+
else:
|
701 |
+
batch_drop_mask = torch.sum(pos_mask, dim=1) > 0 # negative sample indicator
|
702 |
+
|
703 |
+
# drop higher ranks
|
704 |
+
cur_saliency_scores = saliency_scores * drop_mask / tau + ~drop_mask * -1e+3
|
705 |
+
# numerical stability
|
706 |
+
logits = cur_saliency_scores - torch.max(cur_saliency_scores, dim=1, keepdim=True)[0]
|
707 |
+
# softmax
|
708 |
+
exp_logits = torch.exp(logits)
|
709 |
+
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6)
|
710 |
+
|
711 |
+
mean_log_prob_pos = (pos_mask * log_prob * vid_token_mask).sum(1) / (pos_mask.sum(1) + 1e-6)
|
712 |
+
loss = - mean_log_prob_pos * batch_drop_mask
|
713 |
+
loss_rank_contrastive = loss_rank_contrastive + loss.mean()
|
714 |
+
loss_rank_contrastive = loss_rank_contrastive / 12
|
715 |
+
|
716 |
+
saliency_scores = outputs["saliency_scores"] # (N, L)
|
717 |
+
pos_indices = targets["saliency_pos_labels"] # (N, #pairs)
|
718 |
+
neg_indices = targets["saliency_neg_labels"] # (N, #pairs)
|
719 |
+
num_pairs = pos_indices.shape[1] # typically 2 or 4
|
720 |
+
batch_indices = torch.arange(len(saliency_scores)).to(saliency_scores.device)
|
721 |
+
pos_scores = torch.stack(
|
722 |
+
[saliency_scores[batch_indices, pos_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
|
723 |
+
neg_scores = torch.stack(
|
724 |
+
[saliency_scores[batch_indices, neg_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
|
725 |
+
loss_saliency = torch.clamp(self.saliency_margin + neg_scores - pos_scores, min=0).sum() \
|
726 |
+
/ (len(pos_scores) * num_pairs) * 2 # * 2 to keep the loss the same scale
|
727 |
+
|
728 |
+
loss_saliency = loss_saliency + loss_rank_contrastive
|
729 |
+
########### Saliency loss to t2v attn weights ##############
|
730 |
+
"""higher scores for positive clips"""
|
731 |
+
vid_token_mask = outputs["video_mask"]
|
732 |
+
saliency_scores = outputs["t2vattnvalues"].clone() # (N, L)
|
733 |
+
saliency_contrast_label = targets["saliency_all_labels"]
|
734 |
+
|
735 |
+
saliency_scores = vid_token_mask * saliency_scores + (1. - vid_token_mask) * -1e+3
|
736 |
+
|
737 |
+
tau = 0.5
|
738 |
+
loss_rank_contrastive = 0.
|
739 |
+
for rand_idx in range(1, 12):
|
740 |
+
drop_mask = ~(saliency_contrast_label > 100) # no drop
|
741 |
+
pos_mask = (saliency_contrast_label >= rand_idx) # positive when equal or higher than rand_idx
|
742 |
+
if torch.sum(pos_mask) == 0: # no positive sample
|
743 |
+
continue
|
744 |
+
else:
|
745 |
+
batch_drop_mask = torch.sum(pos_mask, dim=1) > 0 # negative sample indicator
|
746 |
+
|
747 |
+
# drop higher ranks
|
748 |
+
cur_saliency_scores = saliency_scores * drop_mask / tau + ~drop_mask * -1e+3
|
749 |
+
# numerical stability
|
750 |
+
logits = cur_saliency_scores - torch.max(cur_saliency_scores, dim=1, keepdim=True)[0]
|
751 |
+
# softmax
|
752 |
+
exp_logits = torch.exp(logits)
|
753 |
+
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6)
|
754 |
+
|
755 |
+
mean_log_prob_pos = (pos_mask * log_prob * vid_token_mask).sum(1) / (pos_mask.sum(1) + 1e-6)
|
756 |
+
loss = - mean_log_prob_pos * batch_drop_mask
|
757 |
+
loss_rank_contrastive = loss_rank_contrastive + loss.mean()
|
758 |
+
loss_rank_contrastive_attn = loss_rank_contrastive / 12
|
759 |
+
|
760 |
+
saliency_scores = outputs["t2vattnvalues"] # (N, L)
|
761 |
+
pos_indices = targets["saliency_pos_labels"] # (N, #pairs)
|
762 |
+
neg_indices = targets["saliency_neg_labels"] # (N, #pairs)
|
763 |
+
num_pairs = pos_indices.shape[1] # typically 2 or 4
|
764 |
+
batch_indices = torch.arange(len(saliency_scores)).to(saliency_scores.device)
|
765 |
+
pos_scores = torch.stack(
|
766 |
+
[saliency_scores[batch_indices, pos_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
|
767 |
+
neg_scores = torch.stack(
|
768 |
+
[saliency_scores[batch_indices, neg_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
|
769 |
+
loss_saliency_attn = torch.clamp(self.saliency_margin + neg_scores - pos_scores, min=0).sum() \
|
770 |
+
/ (len(pos_scores) * num_pairs) * 2 # * 2 to keep the loss the same scale
|
771 |
+
saliency_binary_label = torch.clamp(targets["saliency_all_labels"], 0, 1)
|
772 |
+
logits = saliency_scores.reshape(-1)
|
773 |
+
labels_x = saliency_binary_label.reshape(-1)
|
774 |
+
BCEcriterion = nn.BCELoss()
|
775 |
+
bceloss = BCEcriterion(logits, labels_x)
|
776 |
+
|
777 |
+
loss_saliency_attn = loss_rank_contrastive_attn + bceloss + loss_saliency_attn
|
778 |
+
loss_saliency += (loss_saliency_attn * self.args.lw_wattn)
|
779 |
+
return {"loss_saliency": loss_saliency}
|
780 |
+
|
781 |
+
def loss_contrastive_moment_sentence(self, outputs, targets, indices, log=True):
|
782 |
+
if outputs["memory_moment"] is not None:
|
783 |
+
moment_token = outputs["memory_moment"]
|
784 |
+
nmmemory_moment = outputs["nmmemory_moment"]
|
785 |
+
sentence_token = outputs["sentence_txt"].squeeze(1)
|
786 |
+
sentence_dummy = outputs["sentence_dummy"].squeeze(1) # b, 1, d
|
787 |
+
|
788 |
+
moment_logits = F.normalize(moment_token, dim=1)
|
789 |
+
nmoment_logits = F.normalize(nmmemory_moment, dim=1)
|
790 |
+
sentence_logits = F.normalize(sentence_token, dim=1)
|
791 |
+
dummy_logits = F.normalize(sentence_dummy, dim=1)
|
792 |
+
# import pdb; pdb.set_trace()
|
793 |
+
|
794 |
+
similarity_matrix = torch.matmul(moment_logits, sentence_logits.T) # B B
|
795 |
+
nsimilarity_matrix = torch.matmul(nmoment_logits, sentence_logits.T) # B B
|
796 |
+
similarity_matrix = torch.cat([similarity_matrix, nsimilarity_matrix], dim=1)
|
797 |
+
labels = torch.eye(similarity_matrix.shape[0]).to(sentence_logits.device)
|
798 |
+
nlabels = torch.zeros_like(nsimilarity_matrix).to(sentence_logits.device)
|
799 |
+
labels = torch.cat([labels, nlabels], dim=1).max(dim=1)[1]
|
800 |
+
|
801 |
+
loss_ms_align = self.criterion(similarity_matrix, labels)
|
802 |
+
|
803 |
+
dummy_similarity_matrix = torch.matmul(moment_logits, dummy_logits.T)
|
804 |
+
dummy_nsimilarity_matrix = torch.matmul(nmoment_logits, dummy_logits.T)
|
805 |
+
dummy_similarity_matrix = torch.cat([dummy_similarity_matrix, dummy_nsimilarity_matrix], dim=1)
|
806 |
+
dummy_labels = (~(torch.eye(similarity_matrix.shape[0]).to(sentence_logits.device).bool())).float()
|
807 |
+
dummy_nlabels = torch.ones_like(nsimilarity_matrix).to(sentence_logits.device)
|
808 |
+
dummy_labels = torch.cat([dummy_labels, dummy_nlabels], dim=1).max(dim=1)[1]
|
809 |
+
|
810 |
+
dummy_loss_ms_align = self.criterion(dummy_similarity_matrix, dummy_labels)
|
811 |
+
loss_ms_align += dummy_loss_ms_align
|
812 |
+
video_mask = outputs['video_mask']
|
813 |
+
src_vid = outputs['src_vid'] # [bsz, L_vid, D_vid]
|
814 |
+
moment_mask_ = torch.clamp(targets["relevant_clips"], 0, 1)
|
815 |
+
|
816 |
+
momtokcls_pred = torch.matmul(moment_token.unsqueeze(1), src_vid.permute(0, 2, 1)) # bsz 1 L_vid
|
817 |
+
momtokcls_label = moment_mask_
|
818 |
+
momtokcls_logit = torch.sigmoid(momtokcls_pred)
|
819 |
+
loss_ms_align += (self.bce_criterion(momtokcls_logit.reshape(-1), momtokcls_label.reshape(-1)) * video_mask.reshape(-1)).mean()
|
820 |
+
|
821 |
+
else:
|
822 |
+
loss_ms_align = 0.
|
823 |
+
return {"loss_ms_align": loss_ms_align}
|
824 |
+
#
|
825 |
+
|
826 |
+
def loss_moment2txt_sim_distill(self, outputs, targets, indices, log=True):
|
827 |
+
if outputs["moment2txt_similarity"] is not None:
|
828 |
+
moment2txt_similarity = outputs["moment2txt_similarity"] # bsz L_clip 22
|
829 |
+
moment_mask = outputs["moment_mask"].int() # bsz L_clip 1
|
830 |
+
txt_mask = outputs["txt_mask"].unsqueeze(1).repeat(1, outputs["cate_attn_weights"].size(1), 1) # bsz l_t
|
831 |
+
|
832 |
+
attn_weights = outputs["cate_attn_weights"] # bsz L_clip 22
|
833 |
+
b, L_vid, L_txt = attn_weights.size()
|
834 |
+
loss_distill = self.kld_criterion(
|
835 |
+
torch.log(attn_weights + 1e-6).reshape(b * L_vid, -1),
|
836 |
+
torch.softmax(moment2txt_similarity, dim=-1).clone().detach().reshape(b * L_vid, -1)).mean(1) * moment_mask.reshape(-1)
|
837 |
+
loss_distill = loss_distill.sum() / moment_mask.sum()
|
838 |
+
|
839 |
+
else:
|
840 |
+
loss_distill = 0.
|
841 |
+
return {"loss_distill": loss_distill}
|
842 |
+
|
843 |
+
def loss_orthogonal_dummy(self, outputs, targets, indices, log=True):
|
844 |
+
dummy_tokens = outputs["dummy_tokens"] # (n_dum, dim)
|
845 |
+
if dummy_tokens.size(1) != 1:
|
846 |
+
dummy_tokens_norm = dummy_tokens / dummy_tokens.norm(dim=2)[:, :, None]
|
847 |
+
dummy_tokens_sim = torch.matmul(dummy_tokens_norm, dummy_tokens_norm.permute(0, 2, 1).detach())
|
848 |
+
for i in range(len(dummy_tokens_sim)):
|
849 |
+
dummy_tokens_sim[i].fill_diagonal_(0)
|
850 |
+
loss_dummy_ortho = dummy_tokens_sim.abs().mean()
|
851 |
+
else:
|
852 |
+
loss_dummy_ortho=0.
|
853 |
+
global_tokens = outputs["global_rep_tokens"]
|
854 |
+
|
855 |
+
global_tokens_norm = global_tokens / global_tokens.norm(dim=1)[:, None]
|
856 |
+
global_tokens_sim = torch.matmul(global_tokens_norm, global_tokens_norm.permute(1, 0).detach())
|
857 |
+
for i in range(len(global_tokens_sim)):
|
858 |
+
global_tokens_sim.fill_diagonal_(0)
|
859 |
+
loss_dummy_ortho += global_tokens_sim.abs().mean()
|
860 |
+
return {"loss_orthogonal_dummy": loss_dummy_ortho}
|
861 |
+
|
862 |
+
def loss_contrastive_align(self, outputs, targets, indices, log=True):
|
863 |
+
"""encourage higher scores between matched query span and input text"""
|
864 |
+
normalized_text_embed = outputs["proj_txt_mem"] # (bsz, #tokens, d) text tokens
|
865 |
+
normalized_img_embed = outputs["proj_queries"] # (bsz, #queries, d)
|
866 |
+
logits = torch.einsum(
|
867 |
+
"bmd,bnd->bmn", normalized_img_embed, normalized_text_embed) # (bsz, #queries, #tokens)
|
868 |
+
logits = logits.sum(2) / self.temperature # (bsz, #queries)
|
869 |
+
idx = self._get_src_permutation_idx(indices)
|
870 |
+
positive_map = torch.zeros_like(logits, dtype=torch.bool)
|
871 |
+
positive_map[idx] = True
|
872 |
+
positive_logits = logits.masked_fill(~positive_map, 0)
|
873 |
+
|
874 |
+
pos_term = positive_logits.sum(1) # (bsz, )
|
875 |
+
num_pos = positive_map.sum(1) # (bsz, )
|
876 |
+
neg_term = logits.logsumexp(1) # (bsz, )
|
877 |
+
loss_nce = - pos_term / num_pos + neg_term # (bsz, )
|
878 |
+
losses = {"loss_contrastive_align": loss_nce.mean()}
|
879 |
+
return losses
|
880 |
+
|
881 |
+
def loss_contrastive_align_vid_txt(self, outputs, targets, indices, log=True):
|
882 |
+
"""encourage higher scores between matched query span and input text"""
|
883 |
+
normalized_text_embed = outputs["proj_txt_mem"] # (bsz, #tokens, d) text tokens
|
884 |
+
normalized_img_embed = outputs["proj_queries"] # (bsz, #queries, d)
|
885 |
+
logits = torch.einsum(
|
886 |
+
"bmd,bnd->bmn", normalized_img_embed, normalized_text_embed) # (bsz, #queries, #tokens)
|
887 |
+
logits = logits.sum(2) / self.temperature # (bsz, #queries)
|
888 |
+
idx = self._get_src_permutation_idx(indices)
|
889 |
+
positive_map = torch.zeros_like(logits, dtype=torch.bool)
|
890 |
+
positive_map[idx] = True
|
891 |
+
positive_logits = logits.masked_fill(~positive_map, 0)
|
892 |
+
|
893 |
+
pos_term = positive_logits.sum(1) # (bsz, )
|
894 |
+
num_pos = positive_map.sum(1) # (bsz, )
|
895 |
+
neg_term = logits.logsumexp(1) # (bsz, )
|
896 |
+
loss_nce = - pos_term / num_pos + neg_term # (bsz, )
|
897 |
+
losses = {"loss_contrastive_align": loss_nce.mean()}
|
898 |
+
return losses
|
899 |
+
|
900 |
+
def _get_src_permutation_idx(self, indices):
|
901 |
+
# permute predictions following indices
|
902 |
+
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
|
903 |
+
src_idx = torch.cat([src for (src, _) in indices])
|
904 |
+
return batch_idx, src_idx # two 1D tensors of the same length
|
905 |
+
|
906 |
+
def _get_tgt_permutation_idx(self, indices):
|
907 |
+
# permute targets following indices
|
908 |
+
batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
|
909 |
+
tgt_idx = torch.cat([tgt for (_, tgt) in indices])
|
910 |
+
return batch_idx, tgt_idx
|
911 |
+
|
912 |
+
def get_loss(self, loss, outputs, targets, indices, **kwargs):
|
913 |
+
loss_map = {
|
914 |
+
"spans": self.loss_spans,
|
915 |
+
"labels": self.loss_labels,
|
916 |
+
"contrastive_align": self.loss_contrastive_align,
|
917 |
+
"saliency": self.loss_saliency,
|
918 |
+
"ms_align": self.loss_contrastive_moment_sentence,
|
919 |
+
"distill": self.loss_moment2txt_sim_distill,
|
920 |
+
"orthogonal_dummy":self.loss_orthogonal_dummy
|
921 |
+
}
|
922 |
+
assert loss in loss_map, f'do you really want to compute {loss} loss?'
|
923 |
+
return loss_map[loss](outputs, targets, indices, **kwargs)
|
924 |
+
|
925 |
+
def forward(self, outputs, targets):
|
926 |
+
""" This performs the loss computation.
|
927 |
+
Parameters:
|
928 |
+
outputs: dict of tensors, see the output specification of the model for the format
|
929 |
+
targets: list of dicts, such that len(targets) == batch_size.
|
930 |
+
The expected keys in each dict depends on the losses applied, see each loss' doc
|
931 |
+
"""
|
932 |
+
outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
|
933 |
+
|
934 |
+
# Retrieve the matching between the outputs of the last layer and the targets
|
935 |
+
# list(tuples), each tuple is (pred_span_indices, tgt_span_indices)
|
936 |
+
|
937 |
+
# only for HL, do not use matcher
|
938 |
+
if self.use_matcher:
|
939 |
+
# import pdb; pdb.set_trace()
|
940 |
+
indices = self.matcher(outputs_without_aux, targets)
|
941 |
+
losses_target = self.losses
|
942 |
+
else:
|
943 |
+
indices = None
|
944 |
+
losses_target = ["saliency"]
|
945 |
+
|
946 |
+
# Compute all the requested losses
|
947 |
+
losses = {}
|
948 |
+
for loss in losses_target:
|
949 |
+
losses.update(self.get_loss(loss, outputs, targets, indices))
|
950 |
+
|
951 |
+
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
|
952 |
+
if 'aux_outputs' in outputs:
|
953 |
+
for i, aux_outputs in enumerate(outputs['aux_outputs']):
|
954 |
+
# indices = self.matcher(aux_outputs, targets)
|
955 |
+
if self.use_matcher:
|
956 |
+
indices = self.matcher(aux_outputs, targets)
|
957 |
+
losses_target = self.losses
|
958 |
+
else:
|
959 |
+
indices = None
|
960 |
+
losses_target = ["saliency", "ms_align", "distill", "orthogonal_dummy"]
|
961 |
+
for loss in losses_target:
|
962 |
+
if "saliency" == loss: # skip as it is only in the top layer
|
963 |
+
continue
|
964 |
+
if "ms_align" == loss:
|
965 |
+
continue
|
966 |
+
if "distill" == loss:
|
967 |
+
continue
|
968 |
+
if "orthogonal_dummy" == loss:
|
969 |
+
continue
|
970 |
+
kwargs = {}
|
971 |
+
l_dict = self.get_loss(loss, aux_outputs, targets, indices, **kwargs)
|
972 |
+
l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
|
973 |
+
losses.update(l_dict)
|
974 |
+
return losses
|
975 |
+
|
976 |
+
|
977 |
+
class MLP(nn.Module):
|
978 |
+
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
979 |
+
super().__init__()
|
980 |
+
self.num_layers = num_layers
|
981 |
+
h = [hidden_dim] * (num_layers - 1)
|
982 |
+
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
983 |
+
|
984 |
+
def forward(self, x):
|
985 |
+
for i, layer in enumerate(self.layers):
|
986 |
+
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
987 |
+
return x
|
988 |
+
|
989 |
+
class LinearLayer(nn.Module):
|
990 |
+
"""linear layer configurable with layer normalization, dropout, ReLU."""
|
991 |
+
|
992 |
+
def __init__(self, input_dim, output_dim, layer_norm=True, dropout=0.1, relu=True):
|
993 |
+
super(LinearLayer, self).__init__()
|
994 |
+
self.relu = relu
|
995 |
+
self.layer_norm = layer_norm
|
996 |
+
if layer_norm:
|
997 |
+
self.LayerNorm = nn.LayerNorm(input_dim)
|
998 |
+
layers = [
|
999 |
+
nn.Dropout(dropout),
|
1000 |
+
nn.Linear(input_dim, output_dim)
|
1001 |
+
]
|
1002 |
+
self.net = nn.Sequential(*layers)
|
1003 |
+
|
1004 |
+
def forward(self, x):
|
1005 |
+
"""(N, L, D)"""
|
1006 |
+
|
1007 |
+
if self.layer_norm:
|
1008 |
+
x = self.LayerNorm(x)
|
1009 |
+
x = self.net(x)
|
1010 |
+
if self.relu:
|
1011 |
+
x = F.relu(x, inplace=True)
|
1012 |
+
return x # (N, L, D)
|
1013 |
+
|
1014 |
+
class CGDETRConfig:
|
1015 |
+
def __init__(self, dset_name='charadesSTA', eval_split_name='val', data_ratio=1.0,
|
1016 |
+
results_root='results', exp_id=None, max_es_cnt=200, eval_epoch=5,
|
1017 |
+
grad_clip=0.1, eval_untrained=False, resume_all=False, start_epoch=None,
|
1018 |
+
max_q_l=-1, max_v_l=-1, clip_length=1, max_windows=5, train_path=None,
|
1019 |
+
eval_path=None, no_norm_vfeat=False, no_norm_tfeat=False, v_feat_dirs=None,
|
1020 |
+
t_feat_dir=None, v_feat_dim=770, t_feat_dim=4096, ctx_mode='video_tef',
|
1021 |
+
position_embedding='sine', enc_layers=3, dec_layers=3, t2v_layers=2,
|
1022 |
+
sent_layers=1, moment_layers=1, dummy_layers=2, dim_feedforward=1024,
|
1023 |
+
hidden_dim=256, input_dropout=0.5, dropout=0.1, txt_drop_ratio=0,
|
1024 |
+
use_txt_pos=False, nheads=8, num_queries=10, num_dummies=45,
|
1025 |
+
total_prompts=10, num_prompts=1, pre_norm=False, n_input_proj=2,
|
1026 |
+
contrastive_hdim=64, temperature=0.07, saliency_margin=0.2, aux_loss=True,
|
1027 |
+
span_loss_type='l1', contrastive_align_loss=False, set_cost_span=10,
|
1028 |
+
set_cost_giou=1, set_cost_class=4, lw_saliency=4, lw_wattn=1.0,
|
1029 |
+
lw_ms_align=1.0, lw_distill=1.0, span_loss_coef=10, giou_loss_coef=1,
|
1030 |
+
label_loss_coef=4, eos_coef=0.1, contrastive_align_loss_coef=0.02,
|
1031 |
+
no_sort_results=False, max_before_nms=10, max_after_nms=10,
|
1032 |
+
conf_thd=0.0, nms_thd=-1):
|
1033 |
+
|
1034 |
+
self.dset_name = dset_name
|
1035 |
+
self.eval_split_name = eval_split_name
|
1036 |
+
self.data_ratio = data_ratio
|
1037 |
+
self.results_root = results_root
|
1038 |
+
self.exp_id = exp_id
|
1039 |
+
self.max_es_cnt = max_es_cnt
|
1040 |
+
self.eval_epoch = eval_epoch
|
1041 |
+
self.grad_clip = grad_clip
|
1042 |
+
self.eval_untrained = eval_untrained
|
1043 |
+
self.resume_all = resume_all
|
1044 |
+
self.start_epoch = start_epoch
|
1045 |
+
self.max_q_l = max_q_l
|
1046 |
+
self.max_v_l = max_v_l
|
1047 |
+
self.clip_length = clip_length
|
1048 |
+
self.max_windows = max_windows
|
1049 |
+
self.train_path = train_path
|
1050 |
+
self.eval_path = eval_path
|
1051 |
+
self.no_norm_vfeat = no_norm_vfeat
|
1052 |
+
self.no_norm_tfeat = no_norm_tfeat
|
1053 |
+
self.v_feat_dirs = v_feat_dirs
|
1054 |
+
self.t_feat_dir = t_feat_dir
|
1055 |
+
self.v_feat_dim = v_feat_dim
|
1056 |
+
self.t_feat_dim = t_feat_dim
|
1057 |
+
self.ctx_mode = ctx_mode
|
1058 |
+
self.position_embedding = position_embedding
|
1059 |
+
self.enc_layers = enc_layers
|
1060 |
+
self.dec_layers = dec_layers
|
1061 |
+
self.t2v_layers = t2v_layers
|
1062 |
+
self.sent_layers = sent_layers
|
1063 |
+
self.moment_layers = moment_layers
|
1064 |
+
self.dummy_layers = dummy_layers
|
1065 |
+
self.dim_feedforward = dim_feedforward
|
1066 |
+
self.hidden_dim = hidden_dim
|
1067 |
+
self.input_dropout = input_dropout
|
1068 |
+
self.dropout = dropout
|
1069 |
+
self.txt_drop_ratio = txt_drop_ratio
|
1070 |
+
self.use_txt_pos = use_txt_pos
|
1071 |
+
self.nheads = nheads
|
1072 |
+
self.num_queries = num_queries
|
1073 |
+
self.num_dummies = num_dummies
|
1074 |
+
self.total_prompts = total_prompts
|
1075 |
+
self.num_prompts = num_prompts
|
1076 |
+
self.pre_norm = pre_norm
|
1077 |
+
self.n_input_proj = n_input_proj
|
1078 |
+
self.contrastive_hdim = contrastive_hdim
|
1079 |
+
self.temperature = temperature
|
1080 |
+
self.saliency_margin = saliency_margin
|
1081 |
+
self.aux_loss = aux_loss
|
1082 |
+
self.span_loss_type = span_loss_type
|
1083 |
+
self.contrastive_align_loss = contrastive_align_loss
|
1084 |
+
self.set_cost_span = set_cost_span
|
1085 |
+
self.set_cost_giou = set_cost_giou
|
1086 |
+
self.set_cost_class = set_cost_class
|
1087 |
+
self.lw_saliency = lw_saliency
|
1088 |
+
self.lw_wattn = lw_wattn
|
1089 |
+
self.lw_ms_align = lw_ms_align
|
1090 |
+
self.lw_distill = lw_distill
|
1091 |
+
self.span_loss_coef = span_loss_coef
|
1092 |
+
self.giou_loss_coef = giou_loss_coef
|
1093 |
+
self.label_loss_coef = label_loss_coef
|
1094 |
+
self.eos_coef = eos_coef
|
1095 |
+
self.contrastive_align_loss_coef = contrastive_align_loss_coef
|
1096 |
+
self.no_sort_results = no_sort_results
|
1097 |
+
self.max_before_nms = max_before_nms
|
1098 |
+
self.max_after_nms = max_after_nms
|
1099 |
+
self.conf_thd = conf_thd
|
1100 |
+
self.nms_thd = nms_thd
|
1101 |
+
|
1102 |
+
def build_cgdetr_model():
|
1103 |
+
# device = torch.device(args.device)
|
1104 |
+
# import pdb; pdb.set_trace()
|
1105 |
+
args = CGDETRConfig()
|
1106 |
+
|
1107 |
+
transformer = build_transformer(args)
|
1108 |
+
position_embedding, txt_position_embedding = build_position_encoding(args)
|
1109 |
+
|
1110 |
+
# if args.a_feat_dir is None:
|
1111 |
+
model = CGDETR(
|
1112 |
+
transformer,
|
1113 |
+
position_embedding,
|
1114 |
+
txt_position_embedding,
|
1115 |
+
txt_dim=args.t_feat_dim,
|
1116 |
+
vid_dim=args.v_feat_dim,
|
1117 |
+
num_queries=args.num_queries,
|
1118 |
+
input_dropout=args.input_dropout,
|
1119 |
+
aux_loss=args.aux_loss,
|
1120 |
+
contrastive_align_loss=args.contrastive_align_loss,
|
1121 |
+
contrastive_hdim=args.contrastive_hdim,
|
1122 |
+
span_loss_type=args.span_loss_type,
|
1123 |
+
use_txt_pos=args.use_txt_pos,
|
1124 |
+
n_input_proj=args.n_input_proj,
|
1125 |
+
args=args
|
1126 |
+
)
|
1127 |
+
# else:
|
1128 |
+
# model = CGDETR(
|
1129 |
+
# transformer,
|
1130 |
+
# position_embedding,
|
1131 |
+
# txt_position_embedding,
|
1132 |
+
# txt_dim=args.t_feat_dim,
|
1133 |
+
# vid_dim=args.v_feat_dim,
|
1134 |
+
# aud_dim=args.a_feat_dim,
|
1135 |
+
# num_queries=args.num_queries,
|
1136 |
+
# input_dropout=args.input_dropout,
|
1137 |
+
# aux_loss=args.aux_loss,
|
1138 |
+
# contrastive_align_loss=args.contrastive_align_loss,
|
1139 |
+
# contrastive_hdim=args.contrastive_hdim,
|
1140 |
+
# span_loss_type=args.span_loss_type,
|
1141 |
+
# use_txt_pos=args.use_txt_pos,
|
1142 |
+
# n_input_proj=args.n_input_proj,
|
1143 |
+
# args=args
|
1144 |
+
# )
|
1145 |
+
|
1146 |
+
matcher = build_matcher(args)
|
1147 |
+
weight_dict = {"loss_span": args.span_loss_coef,
|
1148 |
+
"loss_giou": args.giou_loss_coef,
|
1149 |
+
"loss_label": args.label_loss_coef,
|
1150 |
+
"loss_saliency": args.lw_saliency,
|
1151 |
+
"loss_ms_align": args.lw_ms_align,
|
1152 |
+
"loss_distill": args.lw_distill,
|
1153 |
+
"loss_orthogonal_dummy":args.lw_distill}
|
1154 |
+
if args.contrastive_align_loss:
|
1155 |
+
weight_dict["loss_contrastive_align"] = args.contrastive_align_loss_coef
|
1156 |
+
|
1157 |
+
if args.aux_loss:
|
1158 |
+
aux_weight_dict = {}
|
1159 |
+
for i in range(args.dec_layers - 1):
|
1160 |
+
aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items() if k != "loss_saliency"})
|
1161 |
+
weight_dict.update(aux_weight_dict)
|
1162 |
+
|
1163 |
+
losses = ['spans', 'labels', 'saliency', 'ms_align', 'distill', 'orthogonal_dummy']
|
1164 |
+
if args.contrastive_align_loss:
|
1165 |
+
losses += ["contrastive_align"]
|
1166 |
+
|
1167 |
+
# For highlight detection datasets
|
1168 |
+
# use_matcher = not (args.dset_name in ['youtube_uni', 'tvsum'])
|
1169 |
+
use_matcher = True
|
1170 |
+
|
1171 |
+
criterion = SetCriterion(
|
1172 |
+
matcher=matcher, weight_dict=weight_dict, losses=losses,
|
1173 |
+
eos_coef=args.eos_coef, temperature=args.temperature,
|
1174 |
+
span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
|
1175 |
+
saliency_margin=args.saliency_margin, use_matcher=use_matcher, args=args
|
1176 |
+
)
|
1177 |
+
# criterion.to(device)
|
1178 |
+
return model, criterion
|
third_party/cgdetr/cg_detr/position_encoding.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
2 |
+
"""
|
3 |
+
Various positional encodings for the transformer.
|
4 |
+
"""
|
5 |
+
import math
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
|
9 |
+
|
10 |
+
class TrainablePositionalEncoding(nn.Module):
|
11 |
+
"""Construct the embeddings from word, position and token_type embeddings.
|
12 |
+
"""
|
13 |
+
def __init__(self, max_position_embeddings, hidden_size, dropout=0.1):
|
14 |
+
super(TrainablePositionalEncoding, self).__init__()
|
15 |
+
self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
|
16 |
+
self.LayerNorm = nn.LayerNorm(hidden_size)
|
17 |
+
self.dropout = nn.Dropout(dropout)
|
18 |
+
|
19 |
+
def forward(self, input_feat):
|
20 |
+
"""
|
21 |
+
Args:
|
22 |
+
input_feat: (N, L, D)
|
23 |
+
"""
|
24 |
+
bsz, seq_length = input_feat.shape[:2]
|
25 |
+
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_feat.device)
|
26 |
+
position_ids = position_ids.unsqueeze(0).repeat(bsz, 1) # (N, L)
|
27 |
+
|
28 |
+
position_embeddings = self.position_embeddings(position_ids)
|
29 |
+
|
30 |
+
embeddings = self.LayerNorm(input_feat + position_embeddings)
|
31 |
+
embeddings = self.dropout(embeddings)
|
32 |
+
return embeddings
|
33 |
+
|
34 |
+
|
35 |
+
class PositionEmbeddingSine(nn.Module):
|
36 |
+
"""
|
37 |
+
This is a more standard version of the position embedding, very similar to the one
|
38 |
+
used by the Attention is all you need paper, generalized to work on images. (To 1D sequences)
|
39 |
+
"""
|
40 |
+
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
41 |
+
super().__init__()
|
42 |
+
self.num_pos_feats = num_pos_feats
|
43 |
+
self.temperature = temperature
|
44 |
+
self.normalize = normalize
|
45 |
+
if scale is not None and normalize is False:
|
46 |
+
raise ValueError("normalize should be True if scale is passed")
|
47 |
+
if scale is None:
|
48 |
+
scale = 2 * math.pi
|
49 |
+
self.scale = scale
|
50 |
+
|
51 |
+
def forward(self, x, mask):
|
52 |
+
"""
|
53 |
+
Args:
|
54 |
+
x: torch.tensor, (batch_size, L, d)
|
55 |
+
mask: torch.tensor, (batch_size, L), with 1 as valid
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
|
59 |
+
"""
|
60 |
+
assert mask is not None
|
61 |
+
x_embed = mask.cumsum(1, dtype=torch.float32) # (bsz, L)
|
62 |
+
if self.normalize:
|
63 |
+
eps = 1e-6
|
64 |
+
x_embed = x_embed / (x_embed[:, -1:] + eps) * self.scale
|
65 |
+
|
66 |
+
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
67 |
+
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
68 |
+
|
69 |
+
pos_x = x_embed[:, :, None] / dim_t # (bsz, L, num_pos_feats)
|
70 |
+
pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) # (bsz, L, num_pos_feats*2)
|
71 |
+
# import ipdb; ipdb.set_trace()
|
72 |
+
return pos_x # .permute(0, 2, 1) # (bsz, num_pos_feats*2, L)
|
73 |
+
|
74 |
+
|
75 |
+
class PositionEmbeddingLearned(nn.Module):
|
76 |
+
"""
|
77 |
+
Absolute pos embedding, learned.
|
78 |
+
"""
|
79 |
+
def __init__(self, num_pos_feats=256):
|
80 |
+
super().__init__()
|
81 |
+
self.row_embed = nn.Embedding(50, num_pos_feats)
|
82 |
+
self.col_embed = nn.Embedding(50, num_pos_feats)
|
83 |
+
self.reset_parameters()
|
84 |
+
|
85 |
+
def reset_parameters(self):
|
86 |
+
nn.init.uniform_(self.row_embed.weight)
|
87 |
+
nn.init.uniform_(self.col_embed.weight)
|
88 |
+
|
89 |
+
def forward(self, x, mask):
|
90 |
+
h, w = x.shape[-2:]
|
91 |
+
i = torch.arange(w, device=x.device)
|
92 |
+
j = torch.arange(h, device=x.device)
|
93 |
+
x_emb = self.col_embed(i)
|
94 |
+
y_emb = self.row_embed(j)
|
95 |
+
pos = torch.cat([
|
96 |
+
x_emb.unsqueeze(0).repeat(h, 1, 1),
|
97 |
+
y_emb.unsqueeze(1).repeat(1, w, 1),
|
98 |
+
], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
|
99 |
+
return pos
|
100 |
+
|
101 |
+
|
102 |
+
def build_position_encoding(args):
|
103 |
+
N_steps = args.hidden_dim
|
104 |
+
if args.position_embedding in ('v2', 'sine'):
|
105 |
+
# TODO find a better way of exposing other arguments
|
106 |
+
position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
|
107 |
+
# elif args.position_embedding in ('v3', 'learned'):
|
108 |
+
# position_embedding = PositionEmbeddingLearned(N_steps)
|
109 |
+
else:
|
110 |
+
raise ValueError(f"not supported {args.position_embedding}")
|
111 |
+
if args.max_q_l == -1:
|
112 |
+
args.max_q_l = 100
|
113 |
+
txt_pos_embed = TrainablePositionalEncoding(
|
114 |
+
max_position_embeddings=args.max_q_l,
|
115 |
+
hidden_size=args.hidden_dim, dropout=args.input_dropout)
|
116 |
+
return position_embedding, txt_pos_embed
|
third_party/cgdetr/cg_detr/postprocessing_cg_detr.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pprint
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from third_party.cgdetr.utils.basic_utils import load_jsonl
|
5 |
+
from third_party.cgdetr.standalone_eval.eval import eval_submission
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
|
9 |
+
class PostProcessorDETR:
|
10 |
+
def __init__(self, clip_length=2, min_ts_val=0, max_ts_val=150,
|
11 |
+
min_w_l=2, max_w_l=70, move_window_method="center",
|
12 |
+
process_func_names=("clip_window_l", "clip_ts", "round_multiple")):
|
13 |
+
self.clip_length = clip_length
|
14 |
+
self.min_ts_val = min_ts_val
|
15 |
+
self.max_ts_val = max_ts_val
|
16 |
+
self.min_w_l = min_w_l
|
17 |
+
self.max_w_l = max_w_l
|
18 |
+
self.move_window_method = move_window_method
|
19 |
+
self.process_func_names = process_func_names
|
20 |
+
self.name2func = dict(
|
21 |
+
clip_ts=self.clip_min_max_timestamps,
|
22 |
+
round_multiple=self.round_to_multiple_clip_lengths,
|
23 |
+
clip_window_l=self.clip_window_lengths
|
24 |
+
)
|
25 |
+
|
26 |
+
def __call__(self, lines):
|
27 |
+
processed_lines = []
|
28 |
+
for line in tqdm(lines, desc=f"convert to multiples of clip_length={self.clip_length}"):
|
29 |
+
windows_and_scores = torch.tensor(line["pred_relevant_windows"])
|
30 |
+
windows = windows_and_scores[:, :2]
|
31 |
+
for func_name in self.process_func_names:
|
32 |
+
windows = self.name2func[func_name](windows)
|
33 |
+
line["pred_relevant_windows"] = torch.cat(
|
34 |
+
[windows, windows_and_scores[:, 2:3]], dim=1).tolist()
|
35 |
+
line["pred_relevant_windows"] = [e[:2] + [float(f"{e[2]:.4f}")] for e in line["pred_relevant_windows"]]
|
36 |
+
processed_lines.append(line)
|
37 |
+
return processed_lines
|
38 |
+
|
39 |
+
def clip_min_max_timestamps(self, windows):
|
40 |
+
"""
|
41 |
+
windows: (#windows, 2) torch.Tensor
|
42 |
+
ensure timestamps for all windows is within [min_val, max_val], clip is out of boundaries.
|
43 |
+
"""
|
44 |
+
return torch.clamp(windows, min=self.min_ts_val, max=self.max_ts_val)
|
45 |
+
|
46 |
+
def round_to_multiple_clip_lengths(self, windows):
|
47 |
+
"""
|
48 |
+
windows: (#windows, 2) torch.Tensor
|
49 |
+
ensure the final window timestamps are multiples of `clip_length`
|
50 |
+
"""
|
51 |
+
return torch.round(windows / self.clip_length) * self.clip_length
|
52 |
+
|
53 |
+
def clip_window_lengths(self, windows):
|
54 |
+
"""
|
55 |
+
windows: (#windows, 2) np.ndarray
|
56 |
+
ensure the final window duration are within [self.min_w_l, self.max_w_l]
|
57 |
+
"""
|
58 |
+
window_lengths = windows[:, 1] - windows[:, 0]
|
59 |
+
small_rows = window_lengths < self.min_w_l
|
60 |
+
if torch.sum(small_rows) > 0:
|
61 |
+
windows = self.move_windows(
|
62 |
+
windows, small_rows, self.min_w_l, move_method=self.move_window_method)
|
63 |
+
large_rows = window_lengths > self.max_w_l
|
64 |
+
if torch.sum(large_rows) > 0:
|
65 |
+
windows = self.move_windows(
|
66 |
+
windows, large_rows, self.max_w_l, move_method=self.move_window_method)
|
67 |
+
return windows
|
68 |
+
|
69 |
+
@classmethod
|
70 |
+
def move_windows(cls, windows, row_selector, new_length, move_method="left"):
|
71 |
+
"""
|
72 |
+
Args:
|
73 |
+
windows:
|
74 |
+
row_selector:
|
75 |
+
new_length:
|
76 |
+
move_method: str,
|
77 |
+
left: keep left unchanged
|
78 |
+
center: keep center unchanged
|
79 |
+
right: keep right unchanged
|
80 |
+
|
81 |
+
Returns:
|
82 |
+
|
83 |
+
"""
|
84 |
+
# import ipdb;
|
85 |
+
# ipdb.set_trace()
|
86 |
+
if move_method == "left":
|
87 |
+
windows[row_selector, 1] = windows[row_selector, 0] + new_length
|
88 |
+
elif move_method == "right":
|
89 |
+
windows[row_selector, 0] = windows[row_selector, 1] - new_length
|
90 |
+
elif move_method == "center":
|
91 |
+
center = (windows[row_selector, 1] + windows[row_selector, 0]) / 2.
|
92 |
+
windows[row_selector, 0] = center - new_length / 2.
|
93 |
+
windows[row_selector, 1] = center + new_length / 2.
|
94 |
+
return windows
|
95 |
+
|
third_party/cgdetr/cg_detr/scripts/charades_sta/inference.sh
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ckpt_path=$1
|
2 |
+
eval_split_name=$2
|
3 |
+
eval_path=data/highlight_${eval_split_name}_release.jsonl
|
4 |
+
PYTHONPATH=$PYTHONPATH:. python cg_detr/inference.py \
|
5 |
+
--resume ${ckpt_path} \
|
6 |
+
--eval_split_name ${eval_split_name} \
|
7 |
+
--eval_path ${eval_path} \
|
8 |
+
${@:3}
|
third_party/cgdetr/cg_detr/scripts/charades_sta/train.sh
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dset_name=charadesSTA
|
2 |
+
ctx_mode=video_tef
|
3 |
+
v_feat_types=intern
|
4 |
+
t_feat_type=intern
|
5 |
+
results_root=results_charades
|
6 |
+
exp_id=exp
|
7 |
+
|
8 |
+
######## data paths
|
9 |
+
train_path=data/charades_sta/charades_sta_train_tvr_format.jsonl
|
10 |
+
eval_path=data/charades_sta/charades_sta_test_tvr_format.jsonl
|
11 |
+
eval_split_name=val
|
12 |
+
|
13 |
+
######## setup video+text features
|
14 |
+
feat_root=/mnt/petrelfs/lizhilin/CGDETR-main/features/charades
|
15 |
+
|
16 |
+
# video features
|
17 |
+
v_feat_dim=0
|
18 |
+
v_feat_dirs=()
|
19 |
+
if [[ ${v_feat_types} == *"slowfast"* ]]; then
|
20 |
+
v_feat_dirs+=(${feat_root}/slowfast_features)
|
21 |
+
(( v_feat_dim += 2304 )) # double brackets for arithmetic op, no need to use ${v_feat_dim}
|
22 |
+
fi
|
23 |
+
if [[ ${v_feat_types} == *"clip"* ]]; then
|
24 |
+
v_feat_dirs+=(${feat_root}/clip_features)
|
25 |
+
(( v_feat_dim += 512 ))
|
26 |
+
fi
|
27 |
+
if [[ ${v_feat_types} == *"intern"* ]]; then
|
28 |
+
v_feat_dirs+=(${feat_root}/charade_sta_internvideo2_videoclip_6b_w1s)
|
29 |
+
(( v_feat_dim += 768 ))
|
30 |
+
fi
|
31 |
+
|
32 |
+
# text features
|
33 |
+
if [[ ${t_feat_type} == "clip" ]]; then
|
34 |
+
t_feat_dir=${feat_root}/clip_text_features/
|
35 |
+
t_feat_dim=512
|
36 |
+
fi
|
37 |
+
if [[ ${t_feat_type} == *"intern"* ]]; then
|
38 |
+
t_feat_dir=(${feat_root}/charade_sta_internvideo2_llama_text_feature)
|
39 |
+
t_feat_dim=4096
|
40 |
+
fi
|
41 |
+
|
42 |
+
#### training
|
43 |
+
bsz=32
|
44 |
+
eval_bsz=32
|
45 |
+
num_dummies=45
|
46 |
+
num_prompts=2
|
47 |
+
total_prompts=10
|
48 |
+
lr_drop=400
|
49 |
+
enc_layers=3
|
50 |
+
dec_layers=3
|
51 |
+
t2v_layers=2
|
52 |
+
dummy_layers=2
|
53 |
+
moment_layers=1
|
54 |
+
sent_layers=1
|
55 |
+
|
56 |
+
PYTHONPATH=$PYTHONPATH:. \
|
57 |
+
srun -p video5 \
|
58 |
+
--preempt \
|
59 |
+
--job-name=${JOB_NAME} \
|
60 |
+
--ntasks=1 \
|
61 |
+
--gres=gpu:1 \
|
62 |
+
--ntasks-per-node=1 \
|
63 |
+
--cpus-per-task=8 \
|
64 |
+
--kill-on-bad-exit=1 \
|
65 |
+
python cg_detr/train.py \
|
66 |
+
--dset_name ${dset_name} \
|
67 |
+
--ctx_mode ${ctx_mode} \
|
68 |
+
--train_path ${train_path} \
|
69 |
+
--eval_path ${eval_path} \
|
70 |
+
--eval_split_name ${eval_split_name} \
|
71 |
+
--v_feat_dirs ${v_feat_dirs[@]} \
|
72 |
+
--v_feat_dim ${v_feat_dim} \
|
73 |
+
--t_feat_dir ${t_feat_dir} \
|
74 |
+
--t_feat_dim ${t_feat_dim} \
|
75 |
+
--bsz ${bsz} \
|
76 |
+
--results_root ${results_root} \
|
77 |
+
--exp_id ${exp_id} \
|
78 |
+
--max_v_l -1 \
|
79 |
+
--clip_length 1 \
|
80 |
+
--lr 0.0002 \
|
81 |
+
--lr_drop ${lr_drop} \
|
82 |
+
--n_epoch 200 \
|
83 |
+
--contrastive_align_loss_coef 0.002 \
|
84 |
+
--lw_saliency 4 \
|
85 |
+
--enc_layers ${enc_layers} \
|
86 |
+
--dec_layers ${dec_layers} \
|
87 |
+
--t2v_layers ${t2v_layers} \
|
88 |
+
--moment_layers ${moment_layers} \
|
89 |
+
--dummy_layers ${dummy_layers} \
|
90 |
+
--sent_layers ${sent_layers} \
|
91 |
+
--eval_bsz ${eval_bsz} \
|
92 |
+
--num_dummies ${num_dummies} \
|
93 |
+
--num_prompts ${num_prompts} \
|
94 |
+
--total_prompts ${total_prompts} \
|
95 |
+
${@:1}
|
third_party/cgdetr/cg_detr/scripts/inference.sh
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ckpt_path=$1
|
2 |
+
eval_split_name=$2
|
3 |
+
eval_path=data/highlight_${eval_split_name}_release.jsonl
|
4 |
+
echo ${ckpt_path}
|
5 |
+
echo ${eval_split_name}
|
6 |
+
echo ${eval_path}
|
7 |
+
PYTHONPATH=$PYTHONPATH:. python cg_detr/inference.py \
|
8 |
+
--resume ${ckpt_path} \
|
9 |
+
--eval_split_name ${eval_split_name} \
|
10 |
+
--eval_path ${eval_path} \
|
11 |
+
${@:3}
|
third_party/cgdetr/cg_detr/scripts/train.sh
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dset_name=hl
|
2 |
+
ctx_mode=video_tef
|
3 |
+
v_feat_types=intern
|
4 |
+
t_feat_type=intern
|
5 |
+
results_root=results_qvhighlights
|
6 |
+
exp_id=exp
|
7 |
+
|
8 |
+
######## data paths
|
9 |
+
train_path=data/highlight_train_release.jsonl
|
10 |
+
eval_path=data/highlight_val_release.jsonl
|
11 |
+
eval_split_name=val
|
12 |
+
|
13 |
+
######## setup video+text features
|
14 |
+
feat_root=../features/qvhighlight
|
15 |
+
|
16 |
+
# video features
|
17 |
+
v_feat_dim=0
|
18 |
+
v_feat_dirs=()
|
19 |
+
if [[ ${v_feat_types} == *"slowfast"* ]]; then
|
20 |
+
v_feat_dirs+=(${feat_root}/slowfast_features)
|
21 |
+
(( v_feat_dim += 2304 )) # double brackets for arithmetic op, no need to use ${v_feat_dim}
|
22 |
+
fi
|
23 |
+
if [[ ${v_feat_types} == *"clip"* ]]; then
|
24 |
+
v_feat_dirs+=(${feat_root}/clip_features)
|
25 |
+
(( v_feat_dim += 512 ))
|
26 |
+
fi
|
27 |
+
if [[ ${v_feat_types} == *"intern"* ]]; then
|
28 |
+
v_feat_dirs+=(${feat_root}/qvhighlight_internvideo2_videoclip_6b_w2s)
|
29 |
+
(( v_feat_dim += 768 ))
|
30 |
+
fi
|
31 |
+
|
32 |
+
# text features
|
33 |
+
if [[ ${t_feat_type} == "clip" ]]; then
|
34 |
+
t_feat_dir=${feat_root}/clip_text_features/
|
35 |
+
t_feat_dim=512
|
36 |
+
fi
|
37 |
+
if [[ ${t_feat_type} == *"intern"* ]]; then
|
38 |
+
t_feat_dir=(${feat_root}/qvhighlight_internvideo2_llama_text_feature)
|
39 |
+
t_feat_dim=4096
|
40 |
+
fi
|
41 |
+
|
42 |
+
|
43 |
+
#### training
|
44 |
+
bsz=32
|
45 |
+
enc_layers=3
|
46 |
+
dec_layers=3
|
47 |
+
t2v_layers=2
|
48 |
+
moment_layers=1
|
49 |
+
dummy_layers=2
|
50 |
+
sent_layers=1
|
51 |
+
max_v_l=75
|
52 |
+
max_q_l=32
|
53 |
+
|
54 |
+
PYTHONPATH=$PYTHONPATH:. python cg_detr/train.py \
|
55 |
+
--dset_name ${dset_name} \
|
56 |
+
--ctx_mode ${ctx_mode} \
|
57 |
+
--train_path ${train_path} \
|
58 |
+
--eval_path ${eval_path} \
|
59 |
+
--eval_split_name ${eval_split_name} \
|
60 |
+
--v_feat_dirs ${v_feat_dirs[@]} \
|
61 |
+
--v_feat_dim ${v_feat_dim} \
|
62 |
+
--t_feat_dir ${t_feat_dir} \
|
63 |
+
--t_feat_dim ${t_feat_dim} \
|
64 |
+
--bsz ${bsz} \
|
65 |
+
--lr 0.0002 \
|
66 |
+
--results_root ${results_root} \
|
67 |
+
--exp_id ${exp_id} \
|
68 |
+
--enc_layers ${enc_layers} \
|
69 |
+
--dec_layers ${dec_layers} \
|
70 |
+
--t2v_layers ${t2v_layers} \
|
71 |
+
--moment_layers ${moment_layers} \
|
72 |
+
--dummy_layers ${dummy_layers} \
|
73 |
+
--sent_layers ${sent_layers} \
|
74 |
+
--max_v_l ${max_v_l} \
|
75 |
+
--max_q_l ${max_q_l} \
|
76 |
+
${@:1}
|
third_party/cgdetr/cg_detr/span_utils.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def span_xx_to_cxw(xx_spans):
|
5 |
+
"""
|
6 |
+
Args:
|
7 |
+
xx_spans: tensor, (#windows, 2) or (..., 2), each row is a window of format (st, ed)
|
8 |
+
|
9 |
+
Returns:
|
10 |
+
cxw_spans: tensor, (#windows, 2), each row is a window of format (center=(st+ed)/2, width=(ed-st))
|
11 |
+
>>> spans = torch.Tensor([[0, 1], [0.2, 0.4]])
|
12 |
+
>>> span_xx_to_cxw(spans)
|
13 |
+
tensor([[0.5000, 1.0000],
|
14 |
+
[0.3000, 0.2000]])
|
15 |
+
>>> spans = torch.Tensor([[[0, 1], [0.2, 0.4]]])
|
16 |
+
>>> span_xx_to_cxw(spans)
|
17 |
+
tensor([[[0.5000, 1.0000],
|
18 |
+
[0.3000, 0.2000]]])
|
19 |
+
"""
|
20 |
+
center = xx_spans.sum(-1) * 0.5
|
21 |
+
width = xx_spans[..., 1] - xx_spans[..., 0]
|
22 |
+
return torch.stack([center, width], dim=-1)
|
23 |
+
|
24 |
+
|
25 |
+
def span_cxw_to_xx(cxw_spans):
|
26 |
+
"""
|
27 |
+
Args:
|
28 |
+
cxw_spans: tensor, (#windows, 2) or (..., 2), the last dim is a row denoting a window of format (center, width)
|
29 |
+
|
30 |
+
>>> spans = torch.Tensor([[0.5000, 1.0000], [0.3000, 0.2000]])
|
31 |
+
>>> span_cxw_to_xx(spans)
|
32 |
+
tensor([[0.0000, 1.0000],
|
33 |
+
[0.2000, 0.4000]])
|
34 |
+
>>> spans = torch.Tensor([[[0.5000, 1.0000], [0.3000, 0.2000]]])
|
35 |
+
>>> span_cxw_to_xx(spans)
|
36 |
+
tensor([[[0.0000, 1.0000],
|
37 |
+
[0.2000, 0.4000]]])
|
38 |
+
"""
|
39 |
+
x1 = cxw_spans[..., 0] - 0.5 * cxw_spans[..., 1]
|
40 |
+
x2 = cxw_spans[..., 0] + 0.5 * cxw_spans[..., 1]
|
41 |
+
return torch.stack([x1, x2], dim=-1)
|
42 |
+
|
43 |
+
|
44 |
+
def temporal_iou(spans1, spans2):
|
45 |
+
"""
|
46 |
+
Args:
|
47 |
+
spans1: (N, 2) torch.Tensor, each row defines a span [st, ed]
|
48 |
+
spans2: (M, 2) torch.Tensor, ...
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
iou: (N, M) torch.Tensor
|
52 |
+
union: (N, M) torch.Tensor
|
53 |
+
>>> test_spans1 = torch.Tensor([[0, 0.2], [0.5, 1.0]])
|
54 |
+
>>> test_spans2 = torch.Tensor([[0, 0.3], [0., 1.0]])
|
55 |
+
>>> temporal_iou(test_spans1, test_spans2)
|
56 |
+
(tensor([[0.6667, 0.2000],
|
57 |
+
[0.0000, 0.5000]]),
|
58 |
+
tensor([[0.3000, 1.0000],
|
59 |
+
[0.8000, 1.0000]]))
|
60 |
+
"""
|
61 |
+
areas1 = spans1[:, 1] - spans1[:, 0] # (N, )
|
62 |
+
areas2 = spans2[:, 1] - spans2[:, 0] # (M, )
|
63 |
+
|
64 |
+
left = torch.max(spans1[:, None, 0], spans2[:, 0]) # (N, M)
|
65 |
+
right = torch.min(spans1[:, None, 1], spans2[:, 1]) # (N, M)
|
66 |
+
|
67 |
+
inter = (right - left).clamp(min=0) # (N, M)
|
68 |
+
union = areas1[:, None] + areas2 - inter # (N, M)
|
69 |
+
|
70 |
+
iou = inter / union
|
71 |
+
return iou, union
|
72 |
+
|
73 |
+
|
74 |
+
def temporal_intersection_over_pred(gt_spans, pred_spans):
|
75 |
+
""" intersection over the second input spans
|
76 |
+
Args:
|
77 |
+
gt_spans: (N, 2),
|
78 |
+
pred_spans: (M, 2)
|
79 |
+
|
80 |
+
Returns:
|
81 |
+
|
82 |
+
"""
|
83 |
+
left = torch.max(gt_spans[:, None, 0], pred_spans[:, 0])
|
84 |
+
right = torch.min(gt_spans[:, None, 1], pred_spans[:, 1])
|
85 |
+
|
86 |
+
inter = (right - left).clamp(min=0) # (N, M)
|
87 |
+
inter_over_pred = inter / (pred_spans[:, 1] - pred_spans[:, 0])
|
88 |
+
return inter_over_pred
|
89 |
+
|
90 |
+
|
91 |
+
def generalized_temporal_iou(spans1, spans2):
|
92 |
+
"""
|
93 |
+
Generalized IoU from https://giou.stanford.edu/
|
94 |
+
Also reference to DETR implementation of generalized_box_iou
|
95 |
+
https://github.com/facebookresearch/detr/blob/master/util/box_ops.py#L40
|
96 |
+
|
97 |
+
Args:
|
98 |
+
spans1: (N, 2) torch.Tensor, each row defines a span in xx format [st, ed]
|
99 |
+
spans2: (M, 2) torch.Tensor, ...
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
giou: (N, M) torch.Tensor
|
103 |
+
|
104 |
+
>>> test_spans1 = torch.Tensor([[0, 0.2], [0.5, 1.0]])
|
105 |
+
>>> test_spans2 = torch.Tensor([[0, 0.3], [0., 1.0]])
|
106 |
+
>>> generalized_temporal_iou(test_spans1, test_spans2)
|
107 |
+
tensor([[ 0.6667, 0.2000],
|
108 |
+
[-0.2000, 0.5000]])
|
109 |
+
"""
|
110 |
+
spans1 = spans1.float()
|
111 |
+
spans2 = spans2.float()
|
112 |
+
|
113 |
+
if (spans1[:, 1] < spans1[:, 0]).all():
|
114 |
+
torch.save({'spans1': spans1.cpu(), 'spans2': spans2.cpu()}, 'test_spans.pt')
|
115 |
+
spans1[:, 1] += 0.0001
|
116 |
+
print(spans1)
|
117 |
+
assert (spans1[:, 1] >= spans1[:, 0]).all()
|
118 |
+
assert (spans2[:, 1] >= spans2[:, 0]).all()
|
119 |
+
iou, union = temporal_iou(spans1, spans2)
|
120 |
+
|
121 |
+
left = torch.min(spans1[:, None, 0], spans2[:, 0]) # (N, M)
|
122 |
+
right = torch.max(spans1[:, None, 1], spans2[:, 1]) # (N, M)
|
123 |
+
enclosing_area = (right - left).clamp(min=0) # (N, M)
|
124 |
+
|
125 |
+
return iou - (enclosing_area - union) / enclosing_area
|
126 |
+
|
127 |
+
|
third_party/cgdetr/cg_detr/start_end_dataset.py
ADDED
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import Dataset
|
3 |
+
import numpy as np
|
4 |
+
from tqdm import tqdm
|
5 |
+
import random
|
6 |
+
import logging
|
7 |
+
from os.path import join, exists
|
8 |
+
from third_party.cgdetr.utils.basic_utils import load_jsonl, l2_normalize_np_array
|
9 |
+
from third_party.cgdetr.utils.tensor_utils import pad_sequences_1d
|
10 |
+
from third_party.cgdetr.cg_detr.span_utils import span_xx_to_cxw
|
11 |
+
# from torchtext import vocab
|
12 |
+
import torch.nn as nn
|
13 |
+
|
14 |
+
logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
|
17 |
+
class StartEndDataset(Dataset):
|
18 |
+
Q_FEAT_TYPES = ["pooler_output", "last_hidden_state"]
|
19 |
+
"""One line in data loaded from data_path."
|
20 |
+
{
|
21 |
+
"qid": 7803,
|
22 |
+
"query": "Man in gray top walks from outside to inside.",
|
23 |
+
"duration": 150,
|
24 |
+
"vid": "RoripwjYFp8_360.0_510.0",
|
25 |
+
"relevant_clip_ids": [13, 14, 15, 16, 17],
|
26 |
+
"relevant_windows": [[26, 36]]
|
27 |
+
}
|
28 |
+
"""
|
29 |
+
|
30 |
+
def __init__(self, dset_name, data_path, v_feat_dirs, q_feat_dir,
|
31 |
+
q_feat_type="last_hidden_state",
|
32 |
+
max_q_l=32, max_v_l=75, data_ratio=1.0, ctx_mode="video",
|
33 |
+
normalize_v=True, normalize_t=True, load_labels=True,
|
34 |
+
clip_len=2, max_windows=5, span_loss_type="l1", txt_drop_ratio=0,
|
35 |
+
dset_domain=None):
|
36 |
+
self.dset_name = dset_name
|
37 |
+
self.data_path = data_path
|
38 |
+
self.data_ratio = data_ratio
|
39 |
+
self.v_feat_dirs = v_feat_dirs \
|
40 |
+
if isinstance(v_feat_dirs, list) else [v_feat_dirs]
|
41 |
+
self.q_feat_dir = q_feat_dir
|
42 |
+
self.q_feat_type = q_feat_type
|
43 |
+
if max_v_l == -1:
|
44 |
+
max_v_l = 100000000
|
45 |
+
if max_q_l == -1:
|
46 |
+
max_q_l = 100
|
47 |
+
self.max_q_l = max_q_l
|
48 |
+
self.max_v_l = max_v_l
|
49 |
+
self.ctx_mode = ctx_mode
|
50 |
+
self.use_tef = "tef" in ctx_mode
|
51 |
+
self.use_video = "video" in ctx_mode
|
52 |
+
self.normalize_t = normalize_t
|
53 |
+
self.normalize_v = normalize_v
|
54 |
+
self.load_labels = load_labels
|
55 |
+
self.clip_len = clip_len
|
56 |
+
self.max_windows = max_windows # maximum number of windows to use as labels
|
57 |
+
self.span_loss_type = span_loss_type
|
58 |
+
self.txt_drop_ratio = txt_drop_ratio
|
59 |
+
if "val" in data_path or "test" in data_path:
|
60 |
+
assert txt_drop_ratio == 0
|
61 |
+
|
62 |
+
if self.dset_name == 'hl':
|
63 |
+
self.max_q_l = 32
|
64 |
+
self.max_v_l = 75
|
65 |
+
self.clip_len = 2
|
66 |
+
|
67 |
+
# checks
|
68 |
+
assert q_feat_type in self.Q_FEAT_TYPES
|
69 |
+
|
70 |
+
# data
|
71 |
+
self.data = self.load_data()
|
72 |
+
|
73 |
+
self.use_glove = False
|
74 |
+
self.use_glove = 'vgg' in self.v_feat_dirs[0]
|
75 |
+
|
76 |
+
# if self.dset_name == 'charadesSTA' and self.use_glove:
|
77 |
+
# self.vocab = vocab.pretrained_aliases['glove.6B.300d']()
|
78 |
+
# self.vocab.itos.extend(['<unk>'])
|
79 |
+
# self.vocab.stoi['<unk>'] = self.vocab.vectors.shape[0]
|
80 |
+
# self.vocab.vectors = torch.cat(
|
81 |
+
# (self.vocab.vectors, torch.zeros(1, self.vocab.dim)), dim=0)
|
82 |
+
# self.embedding = nn.Embedding.from_pretrained(self.vocab.vectors)
|
83 |
+
|
84 |
+
|
85 |
+
def load_data(self):
|
86 |
+
datalist = load_jsonl(self.data_path)
|
87 |
+
if self.data_ratio != 1:
|
88 |
+
n_examples = int(len(datalist) * self.data_ratio)
|
89 |
+
datalist = datalist[:n_examples]
|
90 |
+
logger.info("Using {}% of the data: {} examples"
|
91 |
+
.format(self.data_ratio * 100, n_examples))
|
92 |
+
return datalist
|
93 |
+
|
94 |
+
def __len__(self):
|
95 |
+
return len(self.data)
|
96 |
+
|
97 |
+
|
98 |
+
def __getitem__(self, index):
|
99 |
+
meta = self.data[index]
|
100 |
+
|
101 |
+
model_inputs = dict()
|
102 |
+
|
103 |
+
if self.use_glove: # False
|
104 |
+
model_inputs["query_feat"] = self.get_query(meta["query"])
|
105 |
+
else:
|
106 |
+
model_inputs["query_feat"] = self._get_query_feat_by_qid(meta["qid"]) # (Dq, ) or (Lq, Dq) # [16, 4096]
|
107 |
+
|
108 |
+
|
109 |
+
if self.use_video : # True
|
110 |
+
model_inputs["video_feat"] = self._get_video_feat_by_vid(meta["vid"]) # (Lv, Dv)
|
111 |
+
ctx_l = len(model_inputs["video_feat"])
|
112 |
+
else:
|
113 |
+
ctx_l = self.max_v_l
|
114 |
+
|
115 |
+
|
116 |
+
if self.use_tef:
|
117 |
+
tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
|
118 |
+
tef_ed = tef_st + 1.0 / ctx_l
|
119 |
+
tef = torch.stack([tef_st, tef_ed], dim=1) # (Lv, 2)
|
120 |
+
if self.use_video :
|
121 |
+
model_inputs["video_feat"] = torch.cat(
|
122 |
+
[model_inputs["video_feat"], tef], dim=1) # (Lv, Dv+2)
|
123 |
+
else:
|
124 |
+
model_inputs["video_feat"] = tef
|
125 |
+
|
126 |
+
|
127 |
+
|
128 |
+
if "relevant_windows" in meta: ## For Qvhighlights test set
|
129 |
+
model_inputs["span_labels"] = self.get_span_labels(meta["relevant_windows"], ctx_l) # (#windows, 2)
|
130 |
+
if self.dset_name in ['charadesSTA', 'tacos', 'activitynet']: ## charades, tacos, nlq
|
131 |
+
model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"], model_inputs["saliency_all_labels"] = \
|
132 |
+
self.get_saliency_labels_sub_as_query(meta["relevant_windows"][0], meta["duration"], ctx_l) # only one gt
|
133 |
+
elif "subs_train" not in self.data_path:
|
134 |
+
model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"], model_inputs["saliency_all_labels"] = \
|
135 |
+
self.get_saliency_labels_all(meta["relevant_clip_ids"], meta["saliency_scores"], ctx_l)
|
136 |
+
else:
|
137 |
+
model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"], model_inputs["saliency_all_labels"] = \
|
138 |
+
self.get_saliency_labels_sub_as_query(meta["relevant_windows"][0], meta["duration"], ctx_l) # only one gt
|
139 |
+
|
140 |
+
if 'qvhighlight' or 'qvhl' in self.data_path:
|
141 |
+
model_inputs["relevant_clip_ids"] = meta["relevant_clip_ids"]
|
142 |
+
model_inputs["vid"] = meta["vid"]
|
143 |
+
model_inputs["qid"] = meta["qid"]
|
144 |
+
return dict(meta=meta, model_inputs=model_inputs)
|
145 |
+
|
146 |
+
# def get_query(self, query):
|
147 |
+
# word_inds = torch.LongTensor(
|
148 |
+
# [self.vocab.stoi.get(w.lower(), 400000) for w in query.split()])
|
149 |
+
# return self.embedding(word_inds)
|
150 |
+
def get_query(self, query):
|
151 |
+
print("ERROR")
|
152 |
+
exit()
|
153 |
+
|
154 |
+
def get_saliency_labels_sub_as_query(self, gt_window, duration, ctx_l, max_n=2):
|
155 |
+
clip_len = duration / ctx_l
|
156 |
+
gt_st = int(gt_window[0] / clip_len)
|
157 |
+
gt_ed = max(0, min(int(gt_window[1] / clip_len), ctx_l) - 1)
|
158 |
+
if gt_st > gt_ed:
|
159 |
+
gt_st = gt_ed
|
160 |
+
|
161 |
+
if gt_st != gt_ed:
|
162 |
+
pos_clip_indices = random.sample(range(gt_st, gt_ed + 1), k=max_n) # 在GT frame idx中随机选两个
|
163 |
+
else:
|
164 |
+
if self.dset_name == 'nlq':
|
165 |
+
pos_clip_indices = [gt_st] * 2
|
166 |
+
else:
|
167 |
+
pos_clip_indices = [gt_st, gt_st]
|
168 |
+
|
169 |
+
neg_pool = list(range(0, gt_st)) + list(range(gt_ed+1, ctx_l)) # 非GT的frame idx
|
170 |
+
try:
|
171 |
+
neg_clip_indices = random.sample(neg_pool, k=max_n) # 在非GT frame idx中随机选两个
|
172 |
+
except:
|
173 |
+
neg_clip_indices = pos_clip_indices
|
174 |
+
|
175 |
+
# For charades_sta
|
176 |
+
score_array = np.zeros(ctx_l)
|
177 |
+
score_array[gt_st:gt_ed + 1] = 1
|
178 |
+
|
179 |
+
return pos_clip_indices, neg_clip_indices, score_array
|
180 |
+
|
181 |
+
|
182 |
+
def get_saliency_labels(self, rel_clip_ids, scores, ctx_l, max_n=1, add_easy_negative=True):
|
183 |
+
"""Sum the scores from the three annotations, then take the two clips with the
|
184 |
+
maximum scores as positive, and two with the minimum scores as negative.
|
185 |
+
Args:
|
186 |
+
rel_clip_ids: list(int), list of relevant clip ids
|
187 |
+
scores: list([anno1_score, anno2_score, anno3_score]),
|
188 |
+
ctx_l: int
|
189 |
+
max_n: int, #clips to use as positive and negative, for easy and hard negative, respectively.
|
190 |
+
add_easy_negative: bool, if True, sample eay negative outside the relevant_clip_ids.
|
191 |
+
"""
|
192 |
+
# indices inside rel_clip_ids
|
193 |
+
scores = np.array(scores) # (#rel_clips, 3)
|
194 |
+
agg_scores = np.sum(scores, 1) # (#rel_clips, )
|
195 |
+
sort_indices = np.argsort(agg_scores) # increasing
|
196 |
+
|
197 |
+
# indices in the whole video
|
198 |
+
# the min(_, ctx_l-1) here is incorrect, but should not cause
|
199 |
+
# much troubles since this should be rarely used.
|
200 |
+
hard_pos_clip_indices = [min(rel_clip_ids[idx], ctx_l-1) for idx in sort_indices[-max_n:]]
|
201 |
+
hard_neg_clip_indices = [min(rel_clip_ids[idx], ctx_l-1) for idx in sort_indices[:max_n]]
|
202 |
+
easy_pos_clip_indices = []
|
203 |
+
easy_neg_clip_indices = []
|
204 |
+
if add_easy_negative:
|
205 |
+
easy_neg_pool = list(set(range(ctx_l)) - set(rel_clip_ids))
|
206 |
+
if len(easy_neg_pool) >= max_n:
|
207 |
+
easy_pos_clip_indices = random.sample(rel_clip_ids, k=max_n)
|
208 |
+
easy_neg_clip_indices = random.sample(easy_neg_pool, k=max_n)
|
209 |
+
else: # copy the hard ones
|
210 |
+
easy_pos_clip_indices = hard_pos_clip_indices
|
211 |
+
easy_neg_clip_indices = hard_neg_clip_indices
|
212 |
+
|
213 |
+
pos_clip_indices = hard_pos_clip_indices + easy_pos_clip_indices
|
214 |
+
neg_clip_indices = hard_neg_clip_indices + easy_neg_clip_indices
|
215 |
+
return pos_clip_indices, neg_clip_indices
|
216 |
+
|
217 |
+
def get_saliency_labels_all(self, rel_clip_ids, scores, ctx_l, max_n=1, add_easy_negative=True):
|
218 |
+
"""Sum the scores from the three annotations, then take the two clips with the
|
219 |
+
maximum scores as positive, and two with the minimum scores as negative.
|
220 |
+
Args:
|
221 |
+
rel_clip_ids: list(int), list of relevant clip ids
|
222 |
+
scores: list([anno1_score, anno2_score, anno3_score]),
|
223 |
+
ctx_l: int
|
224 |
+
max_n: int, #clips to use as positive and negative, for easy and hard negative, respectively.
|
225 |
+
add_easy_negative: bool, if True, sample eay negative outside the relevant_clip_ids.
|
226 |
+
"""
|
227 |
+
# indices inside rel_clip_ids
|
228 |
+
scores = np.array(scores) # (#rel_clips, 3)
|
229 |
+
agg_scores = np.sum(scores, 1) # (#rel_clips, )
|
230 |
+
sort_indices = np.argsort(agg_scores) # increasing
|
231 |
+
|
232 |
+
# score_array = [min(agg_scores[idx], ctx_l-1) for idx in range(ctx_l)]
|
233 |
+
score_array = np.zeros(ctx_l)
|
234 |
+
max_len=ctx_l
|
235 |
+
for idx in range(len(rel_clip_ids)):
|
236 |
+
if rel_clip_ids[idx] >= ctx_l:
|
237 |
+
max_len=max(max_len,rel_clip_ids[idx])
|
238 |
+
# score_array_new = np.zeros(ctx_l + 1)
|
239 |
+
score_array_new = np.zeros(max_len+1)
|
240 |
+
# score_array_new[:ctx_l] = score_array
|
241 |
+
score_array_new[:len(score_array)] = score_array
|
242 |
+
score_array = score_array_new
|
243 |
+
score_array[rel_clip_ids[idx]] = agg_scores[idx]
|
244 |
+
|
245 |
+
# indices in the whole video
|
246 |
+
# the min(_, ctx_l-1) here is incorrect, but should not cause
|
247 |
+
# much troubles since this should be rarely used.
|
248 |
+
hard_pos_clip_indices = [min(rel_clip_ids[idx], ctx_l-1) for idx in sort_indices[-max_n:]]
|
249 |
+
hard_neg_clip_indices = [min(rel_clip_ids[idx], ctx_l-1) for idx in sort_indices[:max_n]]
|
250 |
+
easy_pos_clip_indices = []
|
251 |
+
easy_neg_clip_indices = []
|
252 |
+
if add_easy_negative:
|
253 |
+
easy_neg_pool = list(set(range(ctx_l)) - set(rel_clip_ids))
|
254 |
+
if len(easy_neg_pool) >= max_n:
|
255 |
+
easy_pos_clip_indices = random.sample(rel_clip_ids, k=max_n)
|
256 |
+
easy_neg_clip_indices = random.sample(easy_neg_pool, k=max_n)
|
257 |
+
else: # copy the hard ones
|
258 |
+
easy_pos_clip_indices = hard_pos_clip_indices
|
259 |
+
easy_neg_clip_indices = hard_neg_clip_indices
|
260 |
+
|
261 |
+
pos_clip_indices = hard_pos_clip_indices + easy_pos_clip_indices
|
262 |
+
neg_clip_indices = hard_neg_clip_indices + easy_neg_clip_indices
|
263 |
+
return pos_clip_indices, neg_clip_indices, score_array
|
264 |
+
|
265 |
+
def get_span_labels(self, windows, ctx_l):
|
266 |
+
"""
|
267 |
+
windows: list([st, ed]) in seconds. E.g. [[26, 36]], corresponding st_ed clip_indices [[13, 17]] (inclusive)
|
268 |
+
Note a maximum of `self.max_windows` windows are used.
|
269 |
+
returns Tensor of shape (#windows, 2), each row is [center, width] normalized by video length
|
270 |
+
"""
|
271 |
+
if len(windows) > self.max_windows:
|
272 |
+
random.shuffle(windows)
|
273 |
+
windows = windows[:self.max_windows]
|
274 |
+
if self.span_loss_type == "l1":
|
275 |
+
windows = torch.Tensor(windows) / (ctx_l * self.clip_len) # normalized windows in xx
|
276 |
+
windows = span_xx_to_cxw(windows) # normalized windows in cxw
|
277 |
+
elif self.span_loss_type == "ce":
|
278 |
+
windows = torch.Tensor([
|
279 |
+
[int(w[0] / self.clip_len), min(int(w[1] / self.clip_len), ctx_l) - 1]
|
280 |
+
for w in windows]).long() # inclusive
|
281 |
+
else:
|
282 |
+
raise NotImplementedError
|
283 |
+
return windows
|
284 |
+
|
285 |
+
def _get_query_feat_by_qid(self, qid):
|
286 |
+
# QVhighlight dataset
|
287 |
+
q_feat_path = join(self.q_feat_dir, f"qid{qid}.pt")
|
288 |
+
# q_feat = np.load(q_feat_path)[self.q_feat_type].astype(np.float32)
|
289 |
+
q_feat = torch.load(q_feat_path).numpy().astype(np.float32)
|
290 |
+
if self.q_feat_type == "last_hidden_state":
|
291 |
+
q_feat = q_feat[:self.max_q_l]
|
292 |
+
if self.normalize_t:
|
293 |
+
q_feat = l2_normalize_np_array(q_feat)
|
294 |
+
if self.txt_drop_ratio > 0:
|
295 |
+
q_feat = self.random_drop_rows(q_feat)
|
296 |
+
return torch.from_numpy(q_feat) # (D, ) or (Lq, D)
|
297 |
+
|
298 |
+
def random_drop_rows(self, embeddings):
|
299 |
+
"""randomly mask num_drop rows in embeddings to be zero.
|
300 |
+
Args:
|
301 |
+
embeddings: np.ndarray (L, D)
|
302 |
+
"""
|
303 |
+
num_drop_rows = round(len(embeddings) * self.txt_drop_ratio)
|
304 |
+
if num_drop_rows > 0:
|
305 |
+
row_indices = np.random.choice(
|
306 |
+
len(embeddings), size=num_drop_rows, replace=False)
|
307 |
+
embeddings[row_indices] = 0
|
308 |
+
return embeddings
|
309 |
+
|
310 |
+
def _get_video_feat_by_vid(self, vid):
|
311 |
+
v_feat_list = []
|
312 |
+
for _feat_dir in self.v_feat_dirs:
|
313 |
+
try:
|
314 |
+
_feat_path = join(_feat_dir, f"{vid}.pt")
|
315 |
+
_feat = torch.load(_feat_path)["features"][:self.max_v_l].numpy().astype(np.float32)
|
316 |
+
except:
|
317 |
+
_feat_path = join(_feat_dir, f"{vid}.pt")
|
318 |
+
_feat = torch.load(_feat_path)[:self.max_v_l].numpy().astype(np.float32)
|
319 |
+
if self.normalize_v:
|
320 |
+
_feat = l2_normalize_np_array(_feat)
|
321 |
+
v_feat_list.append(_feat)
|
322 |
+
# some features are slightly longer than the others
|
323 |
+
min_len = min([len(e) for e in v_feat_list])
|
324 |
+
v_feat_list = [e[:min_len] for e in v_feat_list]
|
325 |
+
v_feat = np.concatenate(v_feat_list, axis=1) # (vlen=34, 768)
|
326 |
+
return torch.from_numpy(v_feat) # (Lv, D)
|
327 |
+
|
328 |
+
|
329 |
+
|
330 |
+
def start_end_collate(batch):
|
331 |
+
batch_meta = [e["meta"] for e in batch] # seems no need to collate ?
|
332 |
+
|
333 |
+
model_inputs_keys = batch[0]["model_inputs"].keys()
|
334 |
+
batched_data = dict()
|
335 |
+
for k in model_inputs_keys:
|
336 |
+
if k == "span_labels":
|
337 |
+
batched_data[k] = [dict(spans=e["model_inputs"]["span_labels"]) for e in batch]
|
338 |
+
continue
|
339 |
+
if k in ["saliency_pos_labels", "saliency_neg_labels"]:
|
340 |
+
batched_data[k] = torch.LongTensor([e["model_inputs"][k] for e in batch])
|
341 |
+
continue
|
342 |
+
if k == "saliency_all_labels":
|
343 |
+
pad_data, mask_data = pad_sequences_1d([e["model_inputs"][k] for e in batch], dtype=np.float32, fixed_length=None)
|
344 |
+
batched_data[k] = torch.tensor(pad_data, dtype=torch.float32)
|
345 |
+
continue
|
346 |
+
if k == 'qid':
|
347 |
+
batched_data[k] = [e["model_inputs"][k] for e in batch]
|
348 |
+
continue
|
349 |
+
if k == 'vid':
|
350 |
+
batched_data[k] = [e["model_inputs"][k] for e in batch]
|
351 |
+
continue
|
352 |
+
batched_data[k] = pad_sequences_1d(
|
353 |
+
[e["model_inputs"][k] for e in batch], dtype=torch.float32, fixed_length=None)
|
354 |
+
return batch_meta, batched_data
|
355 |
+
|
356 |
+
|
357 |
+
def prepare_batch_inputs(batched_model_inputs):
|
358 |
+
model_inputs = dict(
|
359 |
+
src_txt=batched_model_inputs["query_feat"][0],
|
360 |
+
src_txt_mask=batched_model_inputs["query_feat"][1],
|
361 |
+
src_vid=batched_model_inputs["video_feat"][0],
|
362 |
+
src_vid_mask=batched_model_inputs["video_feat"][1],
|
363 |
+
vid=batched_model_inputs["vid"],
|
364 |
+
qid=batched_model_inputs["qid"],
|
365 |
+
)
|
366 |
+
targets = {}
|
367 |
+
|
368 |
+
# import pdb; pdb.set_trace()
|
369 |
+
|
370 |
+
if "span_labels" in batched_model_inputs:
|
371 |
+
targets["span_labels"] = [
|
372 |
+
dict(spans=e["spans"])
|
373 |
+
for e in batched_model_inputs["span_labels"]
|
374 |
+
]
|
375 |
+
if "saliency_pos_labels" in batched_model_inputs:
|
376 |
+
for name in ["saliency_pos_labels", "saliency_neg_labels"]:
|
377 |
+
targets[name] = batched_model_inputs[name]
|
378 |
+
|
379 |
+
if "saliency_all_labels" in batched_model_inputs:
|
380 |
+
targets["saliency_all_labels"] = batched_model_inputs["saliency_all_labels"]
|
381 |
+
targets["relevant_clips"] = batched_model_inputs["saliency_all_labels"]
|
382 |
+
targets = None if len(targets) == 0 else targets
|
383 |
+
return model_inputs, targets
|
third_party/cgdetr/cg_detr/text_encoder.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from easydict import EasyDict as edict
|
5 |
+
from xml.model_components import BertAttention, TrainablePositionalEncoding
|
6 |
+
|
7 |
+
|
8 |
+
class TextEncoder(nn.Module):
|
9 |
+
def __init__(self, hidden_size, drop, input_drop, nheads, max_position_embeddings):
|
10 |
+
super().__init__()
|
11 |
+
self.transformer_encoder = BertAttention(edict(
|
12 |
+
hidden_size=hidden_size,
|
13 |
+
intermediate_size=hidden_size,
|
14 |
+
hidden_dropout_prob=drop,
|
15 |
+
attention_probs_dropout_prob=drop,
|
16 |
+
num_attention_heads=nheads,
|
17 |
+
))
|
18 |
+
self.pos_embed = TrainablePositionalEncoding(
|
19 |
+
max_position_embeddings=max_position_embeddings,
|
20 |
+
hidden_size=hidden_size,
|
21 |
+
dropout=input_drop,
|
22 |
+
)
|
23 |
+
self.modular_vector_mapping = nn.Linear(hidden_size, 1, bias=False)
|
24 |
+
|
25 |
+
def forward(self, feat, mask):
|
26 |
+
"""
|
27 |
+
Args:
|
28 |
+
feat: (N, L, D=hidden_size)
|
29 |
+
mask: (N, L) with 1 indicates valid
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
(N, D)
|
33 |
+
"""
|
34 |
+
feat = self.pos_embed(feat) # (N, L, D)
|
35 |
+
feat = self.transformer_encoder(feat, mask.unsqueeze(1))
|
36 |
+
att_scores = self.modular_vector_mapping(feat) # (N, L, 1)
|
37 |
+
att_scores = F.softmax(mask_logits(att_scores, mask.unsqueeze(2)), dim=1)
|
38 |
+
pooled_feat = torch.einsum("blm,bld->bmd", att_scores, feat) # (N, 2 or 1, D)
|
39 |
+
return pooled_feat.squeeze(1)
|
40 |
+
|
41 |
+
|
42 |
+
def mask_logits(target, mask):
|
43 |
+
return target * mask + (1 - mask) * (-1e10)
|
44 |
+
|
45 |
+
|
46 |
+
def build_text_encoder(args):
|
47 |
+
return TextEncoder(
|
48 |
+
hidden_size=args.hidden_dim,
|
49 |
+
drop=args.dropout,
|
50 |
+
input_drop=args.input_dropout,
|
51 |
+
nheads=args.nheads,
|
52 |
+
max_position_embeddings=args.max_q_l
|
53 |
+
)
|
third_party/cgdetr/cg_detr/train.py
ADDED
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import json
|
4 |
+
import pprint
|
5 |
+
import random
|
6 |
+
import numpy as np
|
7 |
+
from tqdm import tqdm, trange
|
8 |
+
from collections import defaultdict
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.backends.cudnn as cudnn
|
13 |
+
from torch.utils.data import DataLoader
|
14 |
+
from torch.utils.tensorboard import SummaryWriter
|
15 |
+
|
16 |
+
# import sys
|
17 |
+
# print(sys.path)
|
18 |
+
# sys.path.insert(os.getcwd(),0)
|
19 |
+
# print(sys.path)
|
20 |
+
|
21 |
+
from cg_detr.config import BaseOptions
|
22 |
+
from cg_detr.start_end_dataset import StartEndDataset, start_end_collate, prepare_batch_inputs
|
23 |
+
from cg_detr.inference import eval_epoch, start_inference, setup_model
|
24 |
+
from utils.basic_utils import AverageMeter, dict_to_markdown
|
25 |
+
from utils.model_utils import count_parameters
|
26 |
+
|
27 |
+
|
28 |
+
import logging
|
29 |
+
logger = logging.getLogger(__name__)
|
30 |
+
logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
|
31 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
32 |
+
level=logging.INFO)
|
33 |
+
|
34 |
+
|
35 |
+
def set_seed(seed, use_cuda=True):
|
36 |
+
random.seed(seed)
|
37 |
+
np.random.seed(seed)
|
38 |
+
torch.manual_seed(seed)
|
39 |
+
if use_cuda:
|
40 |
+
torch.cuda.manual_seed_all(seed)
|
41 |
+
|
42 |
+
|
43 |
+
def train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer):
|
44 |
+
logger.info(f'[Epoch {epoch_i+1}]')
|
45 |
+
model.train()
|
46 |
+
criterion.train()
|
47 |
+
|
48 |
+
# init meters
|
49 |
+
time_meters = defaultdict(AverageMeter)
|
50 |
+
loss_meters = defaultdict(AverageMeter)
|
51 |
+
|
52 |
+
num_training_examples = len(train_loader)
|
53 |
+
timer_dataloading = time.time()
|
54 |
+
for batch_idx, batch in tqdm(enumerate(train_loader),
|
55 |
+
desc="Training Iteration",
|
56 |
+
total=num_training_examples):
|
57 |
+
time_meters["dataloading_time"].update(time.time() - timer_dataloading)
|
58 |
+
timer_start = time.time()
|
59 |
+
model_inputs, targets = prepare_batch_inputs(batch[1], opt.device, non_blocking=opt.pin_memory)
|
60 |
+
time_meters["prepare_inputs_time"].update(time.time() - timer_start)
|
61 |
+
timer_start = time.time()
|
62 |
+
|
63 |
+
outputs = model(**model_inputs, targets=targets)
|
64 |
+
loss_dict = criterion(outputs, targets)
|
65 |
+
weight_dict = criterion.weight_dict
|
66 |
+
losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
|
67 |
+
time_meters["model_forward_time"].update(time.time() - timer_start)
|
68 |
+
|
69 |
+
timer_start = time.time()
|
70 |
+
optimizer.zero_grad()
|
71 |
+
losses.backward()
|
72 |
+
if opt.grad_clip > 0:
|
73 |
+
nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
|
74 |
+
optimizer.step()
|
75 |
+
time_meters["model_backward_time"].update(time.time() - timer_start)
|
76 |
+
|
77 |
+
loss_dict["loss_overall"] = float(losses) # for logging only
|
78 |
+
for k, v in loss_dict.items():
|
79 |
+
loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))
|
80 |
+
|
81 |
+
timer_dataloading = time.time()
|
82 |
+
if opt.debug and batch_idx == 3:
|
83 |
+
break
|
84 |
+
|
85 |
+
# print/add logs
|
86 |
+
tb_writer.add_scalar("Train/lr", float(optimizer.param_groups[0]["lr"]), epoch_i+1)
|
87 |
+
for k, v in loss_meters.items():
|
88 |
+
tb_writer.add_scalar("Train/{}".format(k), v.avg, epoch_i+1)
|
89 |
+
|
90 |
+
to_write = opt.train_log_txt_formatter.format(
|
91 |
+
time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
|
92 |
+
epoch=epoch_i+1,
|
93 |
+
loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in loss_meters.items()]))
|
94 |
+
with open(opt.train_log_filepath, "a") as f:
|
95 |
+
f.write(to_write)
|
96 |
+
|
97 |
+
logger.info("Epoch time stats:")
|
98 |
+
for name, meter in time_meters.items():
|
99 |
+
d = {k: f"{getattr(meter, k):.4f}" for k in ["max", "min", "avg"]}
|
100 |
+
logger.info(f"{name} ==> {d}")
|
101 |
+
|
102 |
+
|
103 |
+
def train(model, criterion, optimizer, lr_scheduler, train_dataset, val_dataset, opt):
|
104 |
+
if opt.device.type == "cuda":
|
105 |
+
logger.info("CUDA enabled.")
|
106 |
+
model.to(opt.device)
|
107 |
+
|
108 |
+
tb_writer = SummaryWriter(opt.tensorboard_log_dir)
|
109 |
+
tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
|
110 |
+
opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
|
111 |
+
opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"
|
112 |
+
|
113 |
+
|
114 |
+
train_loader = DataLoader(
|
115 |
+
train_dataset,
|
116 |
+
collate_fn=start_end_collate,
|
117 |
+
batch_size=opt.bsz,
|
118 |
+
num_workers=opt.num_workers,
|
119 |
+
shuffle=True,
|
120 |
+
pin_memory=opt.pin_memory
|
121 |
+
)
|
122 |
+
|
123 |
+
prev_best_score = 0.
|
124 |
+
es_cnt = 0
|
125 |
+
# start_epoch = 0
|
126 |
+
if opt.start_epoch is None:
|
127 |
+
start_epoch = -1 if opt.eval_untrained else 0
|
128 |
+
else:
|
129 |
+
start_epoch = opt.start_epoch
|
130 |
+
save_submission_filename = "latest_{}_{}_preds.jsonl".format(opt.dset_name, opt.eval_split_name)
|
131 |
+
for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
|
132 |
+
if epoch_i > -1:
|
133 |
+
train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer)
|
134 |
+
lr_scheduler.step()
|
135 |
+
eval_epoch_interval = opt.eval_epoch
|
136 |
+
if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0:
|
137 |
+
with torch.no_grad():
|
138 |
+
metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \
|
139 |
+
eval_epoch(model, val_dataset, opt, save_submission_filename, epoch_i, criterion, tb_writer)
|
140 |
+
|
141 |
+
# log
|
142 |
+
to_write = opt.eval_log_txt_formatter.format(
|
143 |
+
time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
|
144 |
+
epoch=epoch_i,
|
145 |
+
loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in eval_loss_meters.items()]),
|
146 |
+
eval_metrics_str=json.dumps(metrics_no_nms))
|
147 |
+
|
148 |
+
with open(opt.eval_log_filepath, "a") as f:
|
149 |
+
f.write(to_write)
|
150 |
+
logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4)))
|
151 |
+
if metrics_nms is not None:
|
152 |
+
logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4)))
|
153 |
+
|
154 |
+
metrics = metrics_no_nms
|
155 |
+
for k, v in metrics["brief"].items():
|
156 |
+
tb_writer.add_scalar(f"Eval/{k}", float(v), epoch_i+1)
|
157 |
+
|
158 |
+
if opt.dset_name in ['hl']:
|
159 |
+
stop_score = metrics["brief"]["MR-full-mAP"]
|
160 |
+
else:
|
161 |
+
stop_score = (metrics["brief"]["[email protected]"] + metrics["brief"]["[email protected]"]) / 2
|
162 |
+
|
163 |
+
|
164 |
+
if stop_score > prev_best_score:
|
165 |
+
es_cnt = 0
|
166 |
+
prev_best_score = stop_score
|
167 |
+
|
168 |
+
checkpoint = {
|
169 |
+
"model": model.state_dict(),
|
170 |
+
"optimizer": optimizer.state_dict(),
|
171 |
+
"lr_scheduler": lr_scheduler.state_dict(),
|
172 |
+
"epoch": epoch_i,
|
173 |
+
"opt": opt
|
174 |
+
}
|
175 |
+
torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"))
|
176 |
+
|
177 |
+
best_file_paths = [e.replace("latest", "best") for e in latest_file_paths]
|
178 |
+
for src, tgt in zip(latest_file_paths, best_file_paths):
|
179 |
+
os.renames(src, tgt)
|
180 |
+
logger.info("The checkpoint file has been updated.")
|
181 |
+
else:
|
182 |
+
es_cnt += 1
|
183 |
+
if opt.max_es_cnt != -1 and es_cnt > opt.max_es_cnt: # early stop
|
184 |
+
with open(opt.train_log_filepath, "a") as f:
|
185 |
+
f.write(f"Early Stop at epoch {epoch_i}")
|
186 |
+
logger.info(f"\n>>>>> Early stop at epoch {epoch_i} {prev_best_score}\n")
|
187 |
+
break
|
188 |
+
|
189 |
+
# save ckpt
|
190 |
+
checkpoint = {
|
191 |
+
"model": model.state_dict(),
|
192 |
+
"optimizer": optimizer.state_dict(),
|
193 |
+
"lr_scheduler": lr_scheduler.state_dict(),
|
194 |
+
"epoch": epoch_i,
|
195 |
+
"opt": opt
|
196 |
+
}
|
197 |
+
torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_latest.ckpt"))
|
198 |
+
|
199 |
+
# save_interval = 10 if "subs_train" in opt.train_path else 50 # smaller for pretrain
|
200 |
+
# if (epoch_i + 1) % save_interval == 0 or (epoch_i + 1) % opt.lr_drop == 0: # additional copies
|
201 |
+
# checkpoint = {
|
202 |
+
# "model": model.state_dict(),
|
203 |
+
# "optimizer": optimizer.state_dict(),
|
204 |
+
# "epoch": epoch_i,
|
205 |
+
# "opt": opt
|
206 |
+
# }
|
207 |
+
# torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_e{epoch_i:04d}.ckpt"))
|
208 |
+
|
209 |
+
if opt.debug:
|
210 |
+
break
|
211 |
+
|
212 |
+
tb_writer.close()
|
213 |
+
|
214 |
+
|
215 |
+
|
216 |
+
def start_training():
|
217 |
+
logger.info("Setup config, data and model...")
|
218 |
+
opt = BaseOptions().parse()
|
219 |
+
set_seed(opt.seed)
|
220 |
+
if opt.debug: # keep the model run deterministically
|
221 |
+
# 'cudnn.benchmark = True' enabled auto finding the best algorithm for a specific input/net config.
|
222 |
+
# Enable this only when input size is fixed.
|
223 |
+
cudnn.benchmark = False
|
224 |
+
cudnn.deterministic = True
|
225 |
+
|
226 |
+
|
227 |
+
dataset_config = dict(
|
228 |
+
dset_name=opt.dset_name,
|
229 |
+
data_path=opt.train_path,
|
230 |
+
v_feat_dirs=opt.v_feat_dirs,
|
231 |
+
q_feat_dir=opt.t_feat_dir,
|
232 |
+
q_feat_type="last_hidden_state",
|
233 |
+
max_q_l=opt.max_q_l,
|
234 |
+
max_v_l=opt.max_v_l,
|
235 |
+
ctx_mode=opt.ctx_mode,
|
236 |
+
data_ratio=opt.data_ratio,
|
237 |
+
normalize_v=not opt.no_norm_vfeat,
|
238 |
+
normalize_t=not opt.no_norm_tfeat,
|
239 |
+
clip_len=opt.clip_length,
|
240 |
+
max_windows=opt.max_windows,
|
241 |
+
span_loss_type=opt.span_loss_type,
|
242 |
+
txt_drop_ratio=opt.txt_drop_ratio,
|
243 |
+
dset_domain=opt.dset_domain,
|
244 |
+
)
|
245 |
+
dataset_config["data_path"] = opt.train_path
|
246 |
+
train_dataset = StartEndDataset(**dataset_config)
|
247 |
+
# import pdb; pdb.set_trace()
|
248 |
+
# train_dataset[0]
|
249 |
+
|
250 |
+
if opt.eval_path is not None:
|
251 |
+
dataset_config["data_path"] = opt.eval_path
|
252 |
+
dataset_config["txt_drop_ratio"] = 0
|
253 |
+
dataset_config["q_feat_dir"] = opt.t_feat_dir.replace("sub_features", "text_features") # for pretraining
|
254 |
+
# dataset_config["load_labels"] = False # uncomment to calculate eval loss
|
255 |
+
|
256 |
+
eval_dataset = StartEndDataset(**dataset_config)
|
257 |
+
|
258 |
+
else:
|
259 |
+
eval_dataset = None
|
260 |
+
|
261 |
+
model, criterion, optimizer, lr_scheduler = setup_model(opt)
|
262 |
+
logger.info(f"Model {model}")
|
263 |
+
count_parameters(model)
|
264 |
+
logger.info("Start Training...")
|
265 |
+
|
266 |
+
train(model, criterion, optimizer, lr_scheduler, train_dataset, eval_dataset, opt)
|
267 |
+
|
268 |
+
return opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"), opt.eval_split_name, opt.eval_path, opt.debug, opt
|
269 |
+
|
270 |
+
|
271 |
+
if __name__ == '__main__':
|
272 |
+
best_ckpt_path, eval_split_name, eval_path, debug, opt = start_training()
|
273 |
+
if not debug:
|
274 |
+
input_args = ["--resume", best_ckpt_path,
|
275 |
+
"--eval_split_name", eval_split_name,
|
276 |
+
"--eval_path", eval_path]
|
277 |
+
|
278 |
+
import sys
|
279 |
+
sys.argv[1:] = input_args
|
280 |
+
logger.info("\n\n\nFINISHED TRAINING!!!")
|
281 |
+
logger.info("Evaluating model at {}".format(best_ckpt_path))
|
282 |
+
logger.info("Input args {}".format(sys.argv[1:]))
|
283 |
+
start_inference(opt)
|
third_party/cgdetr/cg_detr/transformer.py
ADDED
@@ -0,0 +1,871 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
2 |
+
"""
|
3 |
+
DETR Transformer class.
|
4 |
+
|
5 |
+
Copy-paste from torch.nn.Transformer with modifications:
|
6 |
+
* positional encodings are passed in MHattention
|
7 |
+
* extra LN at the end of encoder is removed
|
8 |
+
* decoder returns a stack of activations from all decoding layers
|
9 |
+
"""
|
10 |
+
import copy
|
11 |
+
from typing import Optional
|
12 |
+
import torch
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from torch import nn, Tensor
|
15 |
+
import math
|
16 |
+
import numpy as np
|
17 |
+
from .attention import MultiheadAttention
|
18 |
+
from .crossattention import MultiheadAttention as cateattention
|
19 |
+
|
20 |
+
class MLP(nn.Module):
|
21 |
+
""" Very simple multi-layer perceptron (also called FFN)"""
|
22 |
+
|
23 |
+
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
24 |
+
super().__init__()
|
25 |
+
self.num_layers = num_layers
|
26 |
+
h = [hidden_dim] * (num_layers - 1)
|
27 |
+
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
28 |
+
|
29 |
+
def forward(self, x):
|
30 |
+
for i, layer in enumerate(self.layers):
|
31 |
+
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
32 |
+
return x
|
33 |
+
|
34 |
+
def inverse_sigmoid(x, eps=1e-3):
|
35 |
+
x = x.clamp(min=0, max=1)
|
36 |
+
x1 = x.clamp(min=eps)
|
37 |
+
x2 = (1 - x).clamp(min=eps)
|
38 |
+
return torch.log(x1/x2)
|
39 |
+
|
40 |
+
def gen_sineembed_for_position(pos_tensor, d_model):
|
41 |
+
# n_query, bs, _ = pos_tensor.size()
|
42 |
+
# sineembed_tensor = torch.zeros(n_query, bs, 256)
|
43 |
+
scale = 2 * math.pi
|
44 |
+
dim_t = torch.arange(d_model//2, dtype=torch.float32, device=pos_tensor.device)
|
45 |
+
dim_t = 10000 ** (2 * (dim_t // 2) / (d_model//2))
|
46 |
+
center_embed = pos_tensor[:, :, 0] * scale
|
47 |
+
pos_x = center_embed[:, :, None] / dim_t
|
48 |
+
pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
|
49 |
+
|
50 |
+
span_embed = pos_tensor[:, :, 1] * scale
|
51 |
+
pos_w = span_embed[:, :, None] / dim_t
|
52 |
+
pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
|
53 |
+
|
54 |
+
pos = torch.cat((pos_x, pos_w), dim=2)
|
55 |
+
return pos
|
56 |
+
|
57 |
+
class Transformer(nn.Module):
|
58 |
+
|
59 |
+
def __init__(self, d_model=512, nhead=8, num_queries=2, num_encoder_layers=6,
|
60 |
+
num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
|
61 |
+
activation="relu", normalize_before=False,
|
62 |
+
return_intermediate_dec=False, query_dim=2,
|
63 |
+
keep_query_pos=False, query_scale_type='cond_elewise',
|
64 |
+
num_patterns=0,
|
65 |
+
modulate_t_attn=True,
|
66 |
+
bbox_embed_diff_each_layer=False, args=None
|
67 |
+
):
|
68 |
+
super().__init__()
|
69 |
+
self.args = args
|
70 |
+
mcls_encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
|
71 |
+
dropout, activation, normalize_before)
|
72 |
+
mcls_encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
73 |
+
self.mcls_encoder = TransformerEncoder(mcls_encoder_layer, args.moment_layers, mcls_encoder_norm)
|
74 |
+
|
75 |
+
t2v_encoder_layer = T2V_TransformerEncoderLayer(d_model, nhead, dim_feedforward,
|
76 |
+
dropout, activation, normalize_before, self.args.num_dummies)
|
77 |
+
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
78 |
+
self.t2v_encoder = TransformerCATEEncoder(t2v_encoder_layer, args.t2v_layers, encoder_norm)
|
79 |
+
|
80 |
+
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
|
81 |
+
dropout, activation, normalize_before)
|
82 |
+
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
83 |
+
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
84 |
+
|
85 |
+
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
|
86 |
+
dropout, activation, normalize_before, keep_query_pos=keep_query_pos)
|
87 |
+
decoder_norm = nn.LayerNorm(d_model)
|
88 |
+
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
|
89 |
+
return_intermediate=return_intermediate_dec,
|
90 |
+
d_model=d_model, query_dim=query_dim, keep_query_pos=keep_query_pos, query_scale_type=query_scale_type,
|
91 |
+
modulate_t_attn=modulate_t_attn,
|
92 |
+
bbox_embed_diff_each_layer=bbox_embed_diff_each_layer)
|
93 |
+
|
94 |
+
self._reset_parameters()
|
95 |
+
|
96 |
+
self.d_model = d_model
|
97 |
+
self.nhead = nhead
|
98 |
+
self.dec_layers = num_decoder_layers
|
99 |
+
self.num_queries = num_queries
|
100 |
+
self.num_patterns = num_patterns
|
101 |
+
|
102 |
+
def _reset_parameters(self):
|
103 |
+
for p in self.parameters():
|
104 |
+
if p.dim() > 1:
|
105 |
+
nn.init.xavier_uniform_(p)
|
106 |
+
|
107 |
+
def forward(self, src, mask, query_embed, pos_embed, video_length=None, moment_idx=None, msrc=None, mpos=None, mmask=None,
|
108 |
+
nmsrc=None, nmpos=None, nmmask=None,
|
109 |
+
ctxtoken=None, gtoken=None, gpos=None, vlen=None):
|
110 |
+
"""
|
111 |
+
Args:
|
112 |
+
src: (batch_size, L, d)
|
113 |
+
mask: (batch_size, L)
|
114 |
+
query_embed: (#queries, d)
|
115 |
+
pos_embed: (batch_size, L, d) the same as src
|
116 |
+
video length: feature shape
|
117 |
+
vlen: actual video length
|
118 |
+
Returns:
|
119 |
+
"""
|
120 |
+
# moment token
|
121 |
+
device = ctxtoken.device
|
122 |
+
if msrc is not None:
|
123 |
+
msrc = msrc.permute(1, 0, 2) # (L, batch_size, d)
|
124 |
+
mpos = mpos.permute(1, 0, 2) # (L, batch_size, d)
|
125 |
+
mmemory = self.mcls_encoder(msrc, src_key_padding_mask=mmask, pos=mpos) # (L, batch_size, d)
|
126 |
+
mmemory_moment, mmemory_frames = mmemory[0], mmemory[1:]
|
127 |
+
else:
|
128 |
+
mmemory_moment = None
|
129 |
+
mmemory_frames = None
|
130 |
+
if nmsrc is not None:
|
131 |
+
nmsrc = nmsrc.permute(1, 0, 2) # (L, batch_size, d)
|
132 |
+
nmpos = nmpos.permute(1, 0, 2) # (L, batch_size, d)
|
133 |
+
nmmemory = self.mcls_encoder(nmsrc, src_key_padding_mask=nmmask, pos=nmpos) # (L, batch_size, d)
|
134 |
+
nmmemory_moment, nmmemory_frames = nmmemory[0], nmmemory[1:]
|
135 |
+
else:
|
136 |
+
nmmemory_moment = None
|
137 |
+
nmmemory_frames = None
|
138 |
+
|
139 |
+
# flatten NxCxHxW to HWxNxC
|
140 |
+
bs, l, d = src.shape
|
141 |
+
src = src.permute(1, 0, 2) # (L, batch_size, d)
|
142 |
+
pos_embed = pos_embed.permute(1, 0, 2) # (L, batch_size, d)
|
143 |
+
refpoint_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # (#queries, batch_size, d)
|
144 |
+
|
145 |
+
# import pdb; pdb.set_trace()
|
146 |
+
# print(src.dtype)
|
147 |
+
t2v_src, attn_weights = self.t2v_encoder(src, src_key_padding_mask=mask, pos=pos_embed, video_length=video_length) # (L, batch_size, d)
|
148 |
+
|
149 |
+
# Saliency Token
|
150 |
+
## Context
|
151 |
+
ctx_src_ = ctxtoken.permute(1, 0, 2) # L b d
|
152 |
+
|
153 |
+
## Distribution Token with 10 prompt tokens
|
154 |
+
### Video Clip featre - context (avg) --> Find top 10 similar tokens --> weighted sum
|
155 |
+
# import pdb; pdb.set_trace()
|
156 |
+
fr_token_sim = torch.softmax(torch.matmul(F.normalize((src[:video_length] - ctx_src_).permute(1, 0, 2), dim=2), F.normalize(gtoken, dim=1).T), dim=-1)# src : b 75 d, token : 10 x d --> b 75 10
|
157 |
+
### Calculate clip importance
|
158 |
+
frame_importance = attn_weights[:, :, self.args.num_dummies:].sum(2).clone().detach() # b 75
|
159 |
+
### Masking empty clips
|
160 |
+
for i in range(len(frame_importance)):
|
161 |
+
frame_importance[i][vlen[i]:] *= 0.
|
162 |
+
### Normalize
|
163 |
+
frame_importance = (frame_importance / frame_importance.sum(1).unsqueeze(1)) * frame_importance.size(1) # b 75
|
164 |
+
### Scale the similarity with importance
|
165 |
+
fr_token_sim = fr_token_sim * frame_importance.unsqueeze(2).repeat(1, 1, fr_token_sim.size(2)) # b 75 10
|
166 |
+
fr_token_sim = fr_token_sim.mean(1) # b 10
|
167 |
+
topk_val, topkidx = torch.topk(fr_token_sim, k=self.args.num_prompts, dim=1)
|
168 |
+
src_ = torch.zeros((len(fr_token_sim), self.d_model), dtype=torch.bfloat16).to(device)
|
169 |
+
for i in range(len(fr_token_sim)):
|
170 |
+
src_[i] = (topk_val[i].unsqueeze(1) * gtoken[topkidx[i]]).sum(0)
|
171 |
+
src_ = src_.reshape(1, src.size(1), -1)
|
172 |
+
|
173 |
+
## Add context and distribution token
|
174 |
+
src_ = src_ + ctx_src_
|
175 |
+
pos_ = gpos.reshape([1, 1, self.d_model]).repeat(1, pos_embed.shape[1], 1)
|
176 |
+
mask_ = torch.tensor([[False]]).to(mask.device).repeat(mask.shape[0], 1)
|
177 |
+
|
178 |
+
# import pdb; pdb.set_trace()
|
179 |
+
src_, _ = self.t2v_encoder(src_, src_key_padding_mask=mask_, pos=pos_,
|
180 |
+
video_length=video_length, dummy=False) # (L, batch_size, d)
|
181 |
+
|
182 |
+
src = torch.cat([src_, t2v_src], dim=0)
|
183 |
+
mask = torch.cat([mask_, mask], dim=1)
|
184 |
+
pos_embed = torch.cat([pos_, pos_embed], dim=0)
|
185 |
+
|
186 |
+
src = src[:video_length + 1]
|
187 |
+
mask = mask[:, :video_length + 1]
|
188 |
+
pos_embed = pos_embed[:video_length + 1]
|
189 |
+
|
190 |
+
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) # (L, batch_size, d)
|
191 |
+
memory_global, memory_local = memory[0], memory[1:]
|
192 |
+
memory_local += memory_global.unsqueeze(0).repeat(memory_local.size(0), 1, 1)
|
193 |
+
mask_local = mask[:, 1:]
|
194 |
+
pos_embed_local = pos_embed[1:]
|
195 |
+
|
196 |
+
tgt = torch.zeros(refpoint_embed.shape[0], bs, d).to(device)
|
197 |
+
tgt = tgt.type(torch.bfloat16)
|
198 |
+
|
199 |
+
# import pdb; pdb.set_trace()
|
200 |
+
hs, references = self.decoder(tgt, memory_local, memory_key_padding_mask=mask_local, pos=pos_embed_local, refpoints_unsigmoid=refpoint_embed) # (#layers, #queries, batch_size, d)
|
201 |
+
memory_local = memory_local.transpose(0, 1) # (batch_size, L, d)
|
202 |
+
|
203 |
+
return hs, references, memory_local, memory_global, attn_weights, mmemory_moment, nmmemory_moment, mmemory_frames, nmmemory_frames
|
204 |
+
|
205 |
+
|
206 |
+
class TransformerCATEEncoder(nn.Module):
|
207 |
+
def __init__(self, encoder_layer, num_layers, norm=None, return_intermediate=False):
|
208 |
+
super().__init__()
|
209 |
+
self.layers = _get_clones(encoder_layer, num_layers)
|
210 |
+
self.num_layers = num_layers
|
211 |
+
self.norm = norm
|
212 |
+
self.return_intermediate = return_intermediate
|
213 |
+
|
214 |
+
def forward(self, src,
|
215 |
+
mask: Optional[Tensor] = None,
|
216 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
217 |
+
pos: Optional[Tensor] = None,
|
218 |
+
dummy=True,
|
219 |
+
**kwargs):
|
220 |
+
output = src
|
221 |
+
|
222 |
+
intermediate = []
|
223 |
+
attn_weights = None
|
224 |
+
for i, layer in enumerate(self.layers):
|
225 |
+
output, attn_weight = layer(output, src_mask=mask,
|
226 |
+
src_key_padding_mask=src_key_padding_mask, pos=pos, dummy=dummy, **kwargs)
|
227 |
+
if attn_weights is None:
|
228 |
+
attn_weights = attn_weight
|
229 |
+
else:
|
230 |
+
attn_weights = attn_weights + attn_weight
|
231 |
+
if self.return_intermediate:
|
232 |
+
intermediate.append(output)
|
233 |
+
attn_weights /= self.num_layers
|
234 |
+
|
235 |
+
if self.norm is not None:
|
236 |
+
output = self.norm(output)
|
237 |
+
|
238 |
+
if self.return_intermediate:
|
239 |
+
return torch.stack(intermediate)
|
240 |
+
|
241 |
+
return output, attn_weights
|
242 |
+
|
243 |
+
class TransformerEncoder(nn.Module):
|
244 |
+
|
245 |
+
def __init__(self, encoder_layer, num_layers, norm=None, return_intermediate=False):
|
246 |
+
super().__init__()
|
247 |
+
self.layers = _get_clones(encoder_layer, num_layers)
|
248 |
+
self.num_layers = num_layers
|
249 |
+
self.norm = norm
|
250 |
+
self.return_intermediate = return_intermediate
|
251 |
+
|
252 |
+
def forward(self, src,
|
253 |
+
mask: Optional[Tensor] = None,
|
254 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
255 |
+
pos: Optional[Tensor] = None,
|
256 |
+
**kwargs):
|
257 |
+
output = src
|
258 |
+
|
259 |
+
intermediate = []
|
260 |
+
|
261 |
+
for layer in self.layers:
|
262 |
+
output = layer(output, src_mask=mask,
|
263 |
+
src_key_padding_mask=src_key_padding_mask, pos=pos, **kwargs)
|
264 |
+
if self.return_intermediate:
|
265 |
+
intermediate.append(output)
|
266 |
+
|
267 |
+
if self.norm is not None:
|
268 |
+
output = self.norm(output)
|
269 |
+
|
270 |
+
if self.return_intermediate:
|
271 |
+
return torch.stack(intermediate)
|
272 |
+
|
273 |
+
return output
|
274 |
+
|
275 |
+
|
276 |
+
class TransformerDecoder(nn.Module):
|
277 |
+
|
278 |
+
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False,
|
279 |
+
d_model=256, query_dim=2, keep_query_pos=False, query_scale_type='cond_elewise',
|
280 |
+
modulate_t_attn=False,
|
281 |
+
bbox_embed_diff_each_layer=False,
|
282 |
+
):
|
283 |
+
super().__init__()
|
284 |
+
self.layers = _get_clones(decoder_layer, num_layers)
|
285 |
+
self.num_layers = num_layers
|
286 |
+
self.norm = norm
|
287 |
+
self.return_intermediate = return_intermediate
|
288 |
+
assert return_intermediate
|
289 |
+
self.query_dim = query_dim
|
290 |
+
|
291 |
+
assert query_scale_type in ['cond_elewise', 'cond_scalar', 'fix_elewise']
|
292 |
+
self.query_scale_type = query_scale_type
|
293 |
+
if query_scale_type == 'cond_elewise':
|
294 |
+
self.query_scale = MLP(d_model, d_model, d_model, 2)
|
295 |
+
elif query_scale_type == 'cond_scalar':
|
296 |
+
self.query_scale = MLP(d_model, d_model, 1, 2)
|
297 |
+
elif query_scale_type == 'fix_elewise':
|
298 |
+
self.query_scale = nn.Embedding(num_layers, d_model)
|
299 |
+
else:
|
300 |
+
raise NotImplementedError("Unknown query_scale_type: {}".format(query_scale_type))
|
301 |
+
|
302 |
+
self.ref_point_head = MLP(d_model, d_model, d_model, 2)
|
303 |
+
|
304 |
+
# self.bbox_embed = None
|
305 |
+
# for DAB-detr
|
306 |
+
if bbox_embed_diff_each_layer:
|
307 |
+
self.bbox_embed = nn.ModuleList([MLP(d_model, d_model, 2, 3) for i in range(num_layers)])
|
308 |
+
else:
|
309 |
+
self.bbox_embed = MLP(d_model, d_model, 2, 3)
|
310 |
+
# init bbox_embed
|
311 |
+
if bbox_embed_diff_each_layer:
|
312 |
+
for bbox_embed in self.bbox_embed:
|
313 |
+
nn.init.constant_(bbox_embed.layers[-1].weight.data, 0)
|
314 |
+
nn.init.constant_(bbox_embed.layers[-1].bias.data, 0)
|
315 |
+
else:
|
316 |
+
nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
|
317 |
+
nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
|
318 |
+
self.d_model = d_model
|
319 |
+
self.modulate_t_attn = modulate_t_attn
|
320 |
+
self.bbox_embed_diff_each_layer = bbox_embed_diff_each_layer
|
321 |
+
|
322 |
+
if modulate_t_attn:
|
323 |
+
self.ref_anchor_head = MLP(d_model, d_model, 1, 2)
|
324 |
+
|
325 |
+
if not keep_query_pos:
|
326 |
+
for layer_id in range(num_layers - 1):
|
327 |
+
self.layers[layer_id + 1].ca_qpos_proj = None
|
328 |
+
|
329 |
+
def forward(self, tgt, memory,
|
330 |
+
tgt_mask: Optional[Tensor] = None,
|
331 |
+
memory_mask: Optional[Tensor] = None,
|
332 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
333 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
334 |
+
pos: Optional[Tensor] = None,
|
335 |
+
refpoints_unsigmoid: Optional[Tensor] = None, # num_queries, bs, 2
|
336 |
+
):
|
337 |
+
output = tgt
|
338 |
+
|
339 |
+
intermediate = []
|
340 |
+
reference_points = refpoints_unsigmoid.sigmoid()
|
341 |
+
ref_points = [reference_points]
|
342 |
+
|
343 |
+
# import pdb; pdb.set_trace()
|
344 |
+
|
345 |
+
for layer_id, layer in enumerate(self.layers):
|
346 |
+
obj_center = reference_points[..., :self.query_dim]
|
347 |
+
# get sine embedding for the query vector
|
348 |
+
query_sine_embed = gen_sineembed_for_position(obj_center, self.d_model)
|
349 |
+
query_sine_embed = query_sine_embed.type(torch.bfloat16)
|
350 |
+
|
351 |
+
query_pos = self.ref_point_head(query_sine_embed)
|
352 |
+
# For the first decoder layer, we do not apply transformation over p_s
|
353 |
+
if self.query_scale_type != 'fix_elewise':
|
354 |
+
if layer_id == 0:
|
355 |
+
pos_transformation = 1
|
356 |
+
else:
|
357 |
+
pos_transformation = self.query_scale(output)
|
358 |
+
else:
|
359 |
+
pos_transformation = self.query_scale.weight[layer_id]
|
360 |
+
|
361 |
+
# apply transformation
|
362 |
+
query_sine_embed = query_sine_embed * pos_transformation
|
363 |
+
|
364 |
+
# modulated HW attentions
|
365 |
+
if self.modulate_t_attn:
|
366 |
+
reft_cond = self.ref_anchor_head(output).sigmoid() # nq, bs, 1
|
367 |
+
|
368 |
+
query_sine_embed *= (reft_cond[..., 0] / obj_center[..., 1]).unsqueeze(-1)
|
369 |
+
|
370 |
+
|
371 |
+
output = layer(output, memory, tgt_mask=tgt_mask,
|
372 |
+
memory_mask=memory_mask,
|
373 |
+
tgt_key_padding_mask=tgt_key_padding_mask,
|
374 |
+
memory_key_padding_mask=memory_key_padding_mask,
|
375 |
+
pos=pos, query_pos=query_pos, query_sine_embed=query_sine_embed,
|
376 |
+
is_first=(layer_id == 0))
|
377 |
+
|
378 |
+
# iter update
|
379 |
+
if self.bbox_embed is not None:
|
380 |
+
if self.bbox_embed_diff_each_layer:
|
381 |
+
tmp = self.bbox_embed[layer_id](output)
|
382 |
+
else:
|
383 |
+
tmp = self.bbox_embed(output)
|
384 |
+
# import ipdb; ipdb.set_trace()
|
385 |
+
tmp[..., :self.query_dim] += inverse_sigmoid(reference_points)
|
386 |
+
new_reference_points = tmp[..., :self.query_dim].sigmoid()
|
387 |
+
if layer_id != self.num_layers - 1:
|
388 |
+
ref_points.append(new_reference_points)
|
389 |
+
reference_points = new_reference_points.detach()
|
390 |
+
|
391 |
+
if self.return_intermediate:
|
392 |
+
intermediate.append(self.norm(output))
|
393 |
+
|
394 |
+
if self.norm is not None:
|
395 |
+
output = self.norm(output)
|
396 |
+
if self.return_intermediate:
|
397 |
+
intermediate.pop()
|
398 |
+
intermediate.append(output)
|
399 |
+
|
400 |
+
if self.return_intermediate:
|
401 |
+
if self.bbox_embed is not None:
|
402 |
+
return [
|
403 |
+
torch.stack(intermediate).transpose(1, 2),
|
404 |
+
torch.stack(ref_points).transpose(1, 2),
|
405 |
+
]
|
406 |
+
else:
|
407 |
+
return [
|
408 |
+
torch.stack(intermediate).transpose(1, 2),
|
409 |
+
reference_points.unsqueeze(0).transpose(1, 2)
|
410 |
+
]
|
411 |
+
|
412 |
+
return output.unsqueeze(0)
|
413 |
+
|
414 |
+
|
415 |
+
class TransformerEncoderLayerThin(nn.Module):
|
416 |
+
|
417 |
+
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
418 |
+
activation="relu", normalize_before=False):
|
419 |
+
super().__init__()
|
420 |
+
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
421 |
+
# Implementation of Feedforward model
|
422 |
+
# self.linear1 = nn.Linear(d_model, dim_feedforward)
|
423 |
+
# self.dropout = nn.Dropout(dropout)
|
424 |
+
# self.linear2 = nn.Linear(dim_feedforward, d_model)
|
425 |
+
self.linear = nn.Linear(d_model, d_model)
|
426 |
+
self.norm = nn.LayerNorm(d_model)
|
427 |
+
self.dropout = nn.Dropout(dropout)
|
428 |
+
|
429 |
+
# self.activation = _get_activation_fn(activation)
|
430 |
+
self.normalize_before = normalize_before
|
431 |
+
|
432 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
433 |
+
return tensor if pos is None else tensor + pos
|
434 |
+
|
435 |
+
def forward_post(self,
|
436 |
+
src,
|
437 |
+
src_mask: Optional[Tensor] = None,
|
438 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
439 |
+
pos: Optional[Tensor] = None):
|
440 |
+
q = k = self.with_pos_embed(src, pos)
|
441 |
+
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
|
442 |
+
key_padding_mask=src_key_padding_mask)[0]
|
443 |
+
src2 = self.linear(src2)
|
444 |
+
src = src + self.dropout(src2)
|
445 |
+
src = self.norm(src)
|
446 |
+
# src = src + self.dropout1(src2)
|
447 |
+
# src = self.norm1(src)
|
448 |
+
# src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
449 |
+
# src = src + self.dropout2(src2)
|
450 |
+
# src = self.norm2(src)
|
451 |
+
return src
|
452 |
+
|
453 |
+
def forward_pre(self, src,
|
454 |
+
src_mask: Optional[Tensor] = None,
|
455 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
456 |
+
pos: Optional[Tensor] = None):
|
457 |
+
"""not used"""
|
458 |
+
src2 = self.norm1(src)
|
459 |
+
q = k = self.with_pos_embed(src2, pos)
|
460 |
+
src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
|
461 |
+
key_padding_mask=src_key_padding_mask)[0]
|
462 |
+
src = src + self.dropout1(src2)
|
463 |
+
src2 = self.norm2(src)
|
464 |
+
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
|
465 |
+
src = src + self.dropout2(src2)
|
466 |
+
return src
|
467 |
+
|
468 |
+
def forward(self, src,
|
469 |
+
src_mask: Optional[Tensor] = None,
|
470 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
471 |
+
pos: Optional[Tensor] = None):
|
472 |
+
if self.normalize_before:
|
473 |
+
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
|
474 |
+
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
|
475 |
+
|
476 |
+
|
477 |
+
class T2V_TransformerEncoderLayer(nn.Module):
|
478 |
+
|
479 |
+
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
480 |
+
activation="relu", normalize_before=False, num_dummies=3):
|
481 |
+
super().__init__()
|
482 |
+
self.self_attn = cateattention(d_model, nhead, dropout=dropout, num_dummies=num_dummies)
|
483 |
+
# Implementation of Feedforward model
|
484 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
485 |
+
self.dropout = nn.Dropout(dropout)
|
486 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
487 |
+
|
488 |
+
self.norm1 = nn.LayerNorm(d_model)
|
489 |
+
self.norm2 = nn.LayerNorm(d_model)
|
490 |
+
self.dropout1 = DropPath(dropout)
|
491 |
+
self.dropout2 = DropPath(dropout)
|
492 |
+
|
493 |
+
self.activation = _get_activation_fn(activation)
|
494 |
+
self.normalize_before = normalize_before
|
495 |
+
self.nhead = nhead
|
496 |
+
|
497 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
498 |
+
return tensor if pos is None else tensor + pos
|
499 |
+
|
500 |
+
def forward_post(self,
|
501 |
+
src,
|
502 |
+
src_mask: Optional[Tensor] = None,
|
503 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
504 |
+
pos: Optional[Tensor] = None,
|
505 |
+
video_length=None, dummy=True):
|
506 |
+
assert video_length is not None
|
507 |
+
pos_src = self.with_pos_embed(src, pos)
|
508 |
+
q, k, v = pos_src[:video_length], pos_src[video_length:], src[video_length:]
|
509 |
+
|
510 |
+
qmask, kmask = src_key_padding_mask[:, :video_length].unsqueeze(2), src_key_padding_mask[:, video_length:].unsqueeze(1)
|
511 |
+
attn_mask = torch.matmul(qmask.float(), kmask.float()).bool().repeat(self.nhead, 1, 1)
|
512 |
+
|
513 |
+
# - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
514 |
+
# If a FloatTensor is provided, it will be directly added to the value.
|
515 |
+
# If a BoolTensor is provided, the positions with the
|
516 |
+
# value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
517 |
+
# - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
518 |
+
# 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
|
519 |
+
# S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
|
520 |
+
# positions. If a BoolTensor is provided, positions with ``True``
|
521 |
+
# are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
522 |
+
# is provided, it will be added to the attention weight.
|
523 |
+
# print(q.shape, k.shape, v.shape, attn_mask.shape, src_key_padding_mask[:, video_length + 1:].shape)
|
524 |
+
|
525 |
+
# import pdb; pdb.set_trace()
|
526 |
+
src2, attn_weights = self.self_attn(q, k, v, attn_mask=attn_mask, key_padding_mask=src_key_padding_mask[:, video_length:], dummy=dummy)
|
527 |
+
|
528 |
+
src2 = src[:video_length] + self.dropout1(src2)
|
529 |
+
src3 = self.norm1(src2)
|
530 |
+
src3 = self.linear2(self.dropout(self.activation(self.linear1(src3))))
|
531 |
+
src2 = src2 + self.dropout2(src3)
|
532 |
+
src2 = self.norm2(src2)
|
533 |
+
|
534 |
+
src = torch.cat([src2, src[video_length:]])
|
535 |
+
return src, attn_weights
|
536 |
+
|
537 |
+
def forward_pre(self, src,
|
538 |
+
src_mask: Optional[Tensor] = None,
|
539 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
540 |
+
pos: Optional[Tensor] = None, dummy=True):
|
541 |
+
pass
|
542 |
+
|
543 |
+
|
544 |
+
def forward(self, src,
|
545 |
+
src_mask: Optional[Tensor] = None,
|
546 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
547 |
+
pos: Optional[Tensor] = None, dummy=True,
|
548 |
+
**kwargs):
|
549 |
+
if self.normalize_before:
|
550 |
+
return self.forward_pre(src, src_mask, src_key_padding_mask, pos, dummy=dummy)
|
551 |
+
return self.forward_post(src, src_mask, src_key_padding_mask, pos, dummy=dummy, **kwargs)
|
552 |
+
|
553 |
+
class TransformerEncoderLayer(nn.Module):
|
554 |
+
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
555 |
+
activation="relu", normalize_before=False):
|
556 |
+
super().__init__()
|
557 |
+
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
558 |
+
# Implementation of Feedforward model
|
559 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
560 |
+
self.dropout = nn.Dropout(dropout)
|
561 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
562 |
+
|
563 |
+
self.norm1 = nn.LayerNorm(d_model)
|
564 |
+
self.norm2 = nn.LayerNorm(d_model)
|
565 |
+
self.dropout1 = DropPath(dropout)
|
566 |
+
self.dropout2 = DropPath(dropout)
|
567 |
+
|
568 |
+
self.activation = _get_activation_fn(activation)
|
569 |
+
self.normalize_before = normalize_before
|
570 |
+
|
571 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
572 |
+
return tensor if pos is None else tensor + pos
|
573 |
+
|
574 |
+
def forward_post(self,
|
575 |
+
src,
|
576 |
+
src_mask: Optional[Tensor] = None,
|
577 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
578 |
+
pos: Optional[Tensor] = None):
|
579 |
+
q = k = self.with_pos_embed(src, pos)
|
580 |
+
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
|
581 |
+
key_padding_mask=src_key_padding_mask)[0]
|
582 |
+
src = src + self.dropout1(src2)
|
583 |
+
src = self.norm1(src)
|
584 |
+
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
585 |
+
src = src + self.dropout2(src2)
|
586 |
+
src = self.norm2(src)
|
587 |
+
return src
|
588 |
+
|
589 |
+
def forward_pre(self, src,
|
590 |
+
src_mask: Optional[Tensor] = None,
|
591 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
592 |
+
pos: Optional[Tensor] = None):
|
593 |
+
pass
|
594 |
+
|
595 |
+
def forward(self, src,
|
596 |
+
src_mask: Optional[Tensor] = None,
|
597 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
598 |
+
pos: Optional[Tensor] = None):
|
599 |
+
if self.normalize_before:
|
600 |
+
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
|
601 |
+
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
|
602 |
+
|
603 |
+
|
604 |
+
class TransformerDecoderLayer(nn.Module):
|
605 |
+
|
606 |
+
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
607 |
+
activation="relu", normalize_before=False, keep_query_pos=False,
|
608 |
+
rm_self_attn_decoder=False):
|
609 |
+
super().__init__()
|
610 |
+
# Decoder Self-Attention
|
611 |
+
if not rm_self_attn_decoder:
|
612 |
+
self.sa_qcontent_proj = nn.Linear(d_model, d_model)
|
613 |
+
self.sa_qpos_proj = nn.Linear(d_model, d_model)
|
614 |
+
self.sa_kcontent_proj = nn.Linear(d_model, d_model)
|
615 |
+
self.sa_kpos_proj = nn.Linear(d_model, d_model)
|
616 |
+
self.sa_v_proj = nn.Linear(d_model, d_model)
|
617 |
+
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, vdim=d_model)
|
618 |
+
|
619 |
+
self.norm1 = nn.LayerNorm(d_model)
|
620 |
+
self.dropout1 = DropPath(dropout)
|
621 |
+
|
622 |
+
# Decoder Cross-Attention
|
623 |
+
self.ca_qcontent_proj = nn.Linear(d_model, d_model)
|
624 |
+
self.ca_qpos_proj = nn.Linear(d_model, d_model)
|
625 |
+
self.ca_kcontent_proj = nn.Linear(d_model, d_model)
|
626 |
+
self.ca_kpos_proj = nn.Linear(d_model, d_model)
|
627 |
+
self.ca_v_proj = nn.Linear(d_model, d_model)
|
628 |
+
self.ca_qpos_sine_proj = nn.Linear(d_model, d_model)
|
629 |
+
self.cross_attn = MultiheadAttention(d_model * 2, nhead, dropout=dropout, vdim=d_model)
|
630 |
+
|
631 |
+
self.nhead = nhead
|
632 |
+
self.rm_self_attn_decoder = rm_self_attn_decoder
|
633 |
+
|
634 |
+
# Implementation of Feedforward model
|
635 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
636 |
+
self.dropout = nn.Dropout(dropout)
|
637 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
638 |
+
|
639 |
+
self.norm2 = nn.LayerNorm(d_model)
|
640 |
+
self.norm3 = nn.LayerNorm(d_model)
|
641 |
+
self.dropout2 = DropPath(dropout)
|
642 |
+
self.dropout3 = DropPath(dropout)
|
643 |
+
|
644 |
+
self.activation = _get_activation_fn(activation)
|
645 |
+
self.normalize_before = normalize_before
|
646 |
+
self.keep_query_pos = keep_query_pos
|
647 |
+
|
648 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
649 |
+
return tensor if pos is None else tensor + pos
|
650 |
+
|
651 |
+
def forward(self, tgt, memory,
|
652 |
+
tgt_mask: Optional[Tensor] = None,
|
653 |
+
memory_mask: Optional[Tensor] = None,
|
654 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
655 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
656 |
+
pos: Optional[Tensor] = None,
|
657 |
+
query_pos: Optional[Tensor] = None,
|
658 |
+
query_sine_embed=None,
|
659 |
+
is_first=False):
|
660 |
+
|
661 |
+
# ========== Begin of Self-Attention =============
|
662 |
+
if not self.rm_self_attn_decoder:
|
663 |
+
# Apply projections here
|
664 |
+
# shape: num_queries x batch_size x 256
|
665 |
+
q_content = self.sa_qcontent_proj(tgt) # target is the input of the first decoder layer. zero by default.
|
666 |
+
q_pos = self.sa_qpos_proj(query_pos)
|
667 |
+
k_content = self.sa_kcontent_proj(tgt)
|
668 |
+
k_pos = self.sa_kpos_proj(query_pos)
|
669 |
+
v = self.sa_v_proj(tgt)
|
670 |
+
|
671 |
+
num_queries, bs, n_model = q_content.shape
|
672 |
+
hw, _, _ = k_content.shape
|
673 |
+
|
674 |
+
q = q_content + q_pos
|
675 |
+
k = k_content + k_pos
|
676 |
+
|
677 |
+
tgt2 = self.self_attn(q, k, value=v, attn_mask=tgt_mask,
|
678 |
+
key_padding_mask=tgt_key_padding_mask)[0]
|
679 |
+
# ========== End of Self-Attention =============
|
680 |
+
|
681 |
+
tgt = tgt + self.dropout1(tgt2)
|
682 |
+
tgt = self.norm1(tgt)
|
683 |
+
|
684 |
+
# ========== Begin of Cross-Attention =============
|
685 |
+
# Apply projections here
|
686 |
+
# shape: num_queries x batch_size x 256
|
687 |
+
q_content = self.ca_qcontent_proj(tgt)
|
688 |
+
k_content = self.ca_kcontent_proj(memory)
|
689 |
+
v = self.ca_v_proj(memory)
|
690 |
+
|
691 |
+
num_queries, bs, n_model = q_content.shape
|
692 |
+
hw, _, _ = k_content.shape
|
693 |
+
|
694 |
+
k_pos = self.ca_kpos_proj(pos)
|
695 |
+
|
696 |
+
# For the first decoder layer, we concatenate the positional embedding predicted from
|
697 |
+
# the object query (the positional embedding) into the original query (key) in DETR.
|
698 |
+
if is_first or self.keep_query_pos:
|
699 |
+
q_pos = self.ca_qpos_proj(query_pos)
|
700 |
+
q = q_content + q_pos
|
701 |
+
k = k_content + k_pos
|
702 |
+
else:
|
703 |
+
q = q_content
|
704 |
+
k = k_content
|
705 |
+
|
706 |
+
q = q.view(num_queries, bs, self.nhead, n_model // self.nhead)
|
707 |
+
query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed)
|
708 |
+
query_sine_embed = query_sine_embed.view(num_queries, bs, self.nhead, n_model // self.nhead)
|
709 |
+
q = torch.cat([q, query_sine_embed], dim=3).view(num_queries, bs, n_model * 2)
|
710 |
+
k = k.view(hw, bs, self.nhead, n_model // self.nhead)
|
711 |
+
k_pos = k_pos.view(hw, bs, self.nhead, n_model // self.nhead)
|
712 |
+
k = torch.cat([k, k_pos], dim=3).view(hw, bs, n_model * 2)
|
713 |
+
|
714 |
+
tgt2 = self.cross_attn(query=q,
|
715 |
+
key=k,
|
716 |
+
value=v, attn_mask=memory_mask,
|
717 |
+
key_padding_mask=memory_key_padding_mask)[0]
|
718 |
+
# ========== End of Cross-Attention =============
|
719 |
+
|
720 |
+
tgt = tgt + self.dropout2(tgt2)
|
721 |
+
tgt = self.norm2(tgt)
|
722 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
723 |
+
tgt = tgt + self.dropout3(tgt2)
|
724 |
+
tgt = self.norm3(tgt)
|
725 |
+
return tgt
|
726 |
+
|
727 |
+
|
728 |
+
class TransformerDecoderLayerThin(nn.Module):
|
729 |
+
"""removed intermediate layer"""
|
730 |
+
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
731 |
+
activation="relu", normalize_before=False):
|
732 |
+
super().__init__()
|
733 |
+
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
734 |
+
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
735 |
+
# Implementation of Feedforward model
|
736 |
+
self.linear1 = nn.Linear(d_model, d_model)
|
737 |
+
|
738 |
+
|
739 |
+
self.norm1 = nn.LayerNorm(d_model)
|
740 |
+
self.norm2 = nn.LayerNorm(d_model)
|
741 |
+
# self.norm3 = nn.LayerNorm(d_model)
|
742 |
+
self.dropout1 = DropPath(dropout)
|
743 |
+
self.dropout2 = DropPath(dropout)
|
744 |
+
|
745 |
+
|
746 |
+
# self.activation = _get_activation_fn(activation)
|
747 |
+
self.normalize_before = normalize_before
|
748 |
+
|
749 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
750 |
+
return tensor if pos is None else tensor + pos
|
751 |
+
|
752 |
+
def forward_post(self, tgt, memory,
|
753 |
+
tgt_mask: Optional[Tensor] = None,
|
754 |
+
memory_mask: Optional[Tensor] = None,
|
755 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
756 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
757 |
+
pos: Optional[Tensor] = None,
|
758 |
+
query_pos: Optional[Tensor] = None):
|
759 |
+
q = k = self.with_pos_embed(tgt, query_pos)
|
760 |
+
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
|
761 |
+
key_padding_mask=tgt_key_padding_mask)[0]
|
762 |
+
tgt = tgt + self.dropout1(tgt2)
|
763 |
+
tgt = self.norm1(tgt)
|
764 |
+
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
|
765 |
+
key=self.with_pos_embed(memory, pos),
|
766 |
+
value=memory, attn_mask=memory_mask,
|
767 |
+
key_padding_mask=memory_key_padding_mask)[0]
|
768 |
+
tgt2 = self.linear1(tgt2)
|
769 |
+
tgt = tgt + self.dropout2(tgt2)
|
770 |
+
tgt = self.norm2(tgt)
|
771 |
+
return tgt
|
772 |
+
|
773 |
+
def forward_pre(self, tgt, memory,
|
774 |
+
tgt_mask: Optional[Tensor] = None,
|
775 |
+
memory_mask: Optional[Tensor] = None,
|
776 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
777 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
778 |
+
pos: Optional[Tensor] = None,
|
779 |
+
query_pos: Optional[Tensor] = None):
|
780 |
+
tgt2 = self.norm1(tgt)
|
781 |
+
q = k = self.with_pos_embed(tgt2, query_pos)
|
782 |
+
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
|
783 |
+
key_padding_mask=tgt_key_padding_mask)[0]
|
784 |
+
tgt = tgt + self.dropout1(tgt2)
|
785 |
+
tgt2 = self.norm2(tgt)
|
786 |
+
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
|
787 |
+
key=self.with_pos_embed(memory, pos),
|
788 |
+
value=memory, attn_mask=memory_mask,
|
789 |
+
key_padding_mask=memory_key_padding_mask)[0]
|
790 |
+
tgt = tgt + self.dropout2(tgt2)
|
791 |
+
tgt2 = self.norm3(tgt)
|
792 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
793 |
+
tgt = tgt + self.dropout3(tgt2)
|
794 |
+
return tgt
|
795 |
+
|
796 |
+
def forward(self, tgt, memory,
|
797 |
+
tgt_mask: Optional[Tensor] = None,
|
798 |
+
memory_mask: Optional[Tensor] = None,
|
799 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
800 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
801 |
+
pos: Optional[Tensor] = None,
|
802 |
+
query_pos: Optional[Tensor] = None):
|
803 |
+
if self.normalize_before:
|
804 |
+
return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
|
805 |
+
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
|
806 |
+
return self.forward_post(tgt, memory, tgt_mask, memory_mask,
|
807 |
+
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
|
808 |
+
|
809 |
+
|
810 |
+
|
811 |
+
def _get_clones(module, N):
|
812 |
+
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
813 |
+
|
814 |
+
|
815 |
+
def build_transformer(args):
|
816 |
+
return Transformer(
|
817 |
+
d_model=args.hidden_dim,
|
818 |
+
dropout=args.dropout,
|
819 |
+
nhead=args.nheads,
|
820 |
+
dim_feedforward=args.dim_feedforward,
|
821 |
+
num_encoder_layers=args.enc_layers,
|
822 |
+
num_decoder_layers=args.dec_layers,
|
823 |
+
normalize_before=args.pre_norm,
|
824 |
+
return_intermediate_dec=True,
|
825 |
+
activation='prelu',
|
826 |
+
args=args
|
827 |
+
)
|
828 |
+
|
829 |
+
def drop_path(x, drop_prob=0.0, training=False):
|
830 |
+
"""
|
831 |
+
Stochastic Depth per sample.
|
832 |
+
"""
|
833 |
+
if drop_prob == 0.0 or not training:
|
834 |
+
return x
|
835 |
+
|
836 |
+
keep_prob = 1 - drop_prob
|
837 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
|
838 |
+
mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
839 |
+
mask.floor_()
|
840 |
+
x = x.div(keep_prob) * mask
|
841 |
+
|
842 |
+
return x
|
843 |
+
|
844 |
+
class DropPath(nn.Module):
|
845 |
+
"""
|
846 |
+
Drop paths per sample (when applied in main path of residual blocks).
|
847 |
+
"""
|
848 |
+
|
849 |
+
def __init__(self, drop_prob=None):
|
850 |
+
super(DropPath, self).__init__()
|
851 |
+
|
852 |
+
self.drop_prob = drop_prob
|
853 |
+
|
854 |
+
def forward(self, x):
|
855 |
+
x = x.permute(1, 0, 2)
|
856 |
+
res = drop_path(x, self.drop_prob, self.training)
|
857 |
+
return res.permute(1, 0, 2)
|
858 |
+
|
859 |
+
def _get_activation_fn(activation):
|
860 |
+
"""Return an activation function given a string"""
|
861 |
+
if activation == "relu":
|
862 |
+
return F.relu
|
863 |
+
if activation == "gelu":
|
864 |
+
return F.gelu
|
865 |
+
if activation == "glu":
|
866 |
+
return F.glu
|
867 |
+
if activation == "prelu":
|
868 |
+
return nn.PReLU()
|
869 |
+
if activation == "selu":
|
870 |
+
return F.selu
|
871 |
+
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
third_party/cgdetr/data/LICENSE
ADDED
@@ -0,0 +1,437 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Attribution-NonCommercial-ShareAlike 4.0 International
|
2 |
+
|
3 |
+
=======================================================================
|
4 |
+
|
5 |
+
Creative Commons Corporation ("Creative Commons") is not a law firm and
|
6 |
+
does not provide legal services or legal advice. Distribution of
|
7 |
+
Creative Commons public licenses does not create a lawyer-client or
|
8 |
+
other relationship. Creative Commons makes its licenses and related
|
9 |
+
information available on an "as-is" basis. Creative Commons gives no
|
10 |
+
warranties regarding its licenses, any material licensed under their
|
11 |
+
terms and conditions, or any related information. Creative Commons
|
12 |
+
disclaims all liability for damages resulting from their use to the
|
13 |
+
fullest extent possible.
|
14 |
+
|
15 |
+
Using Creative Commons Public Licenses
|
16 |
+
|
17 |
+
Creative Commons public licenses provide a standard set of terms and
|
18 |
+
conditions that creators and other rights holders may use to share
|
19 |
+
original works of authorship and other material subject to copyright
|
20 |
+
and certain other rights specified in the public license below. The
|
21 |
+
following considerations are for informational purposes only, are not
|
22 |
+
exhaustive, and do not form part of our licenses.
|
23 |
+
|
24 |
+
Considerations for licensors: Our public licenses are
|
25 |
+
intended for use by those authorized to give the public
|
26 |
+
permission to use material in ways otherwise restricted by
|
27 |
+
copyright and certain other rights. Our licenses are
|
28 |
+
irrevocable. Licensors should read and understand the terms
|
29 |
+
and conditions of the license they choose before applying it.
|
30 |
+
Licensors should also secure all rights necessary before
|
31 |
+
applying our licenses so that the public can reuse the
|
32 |
+
material as expected. Licensors should clearly mark any
|
33 |
+
material not subject to the license. This includes other CC-
|
34 |
+
licensed material, or material used under an exception or
|
35 |
+
limitation to copyright. More considerations for licensors:
|
36 |
+
wiki.creativecommons.org/Considerations_for_licensors
|
37 |
+
|
38 |
+
Considerations for the public: By using one of our public
|
39 |
+
licenses, a licensor grants the public permission to use the
|
40 |
+
licensed material under specified terms and conditions. If
|
41 |
+
the licensor's permission is not necessary for any reason--for
|
42 |
+
example, because of any applicable exception or limitation to
|
43 |
+
copyright--then that use is not regulated by the license. Our
|
44 |
+
licenses grant only permissions under copyright and certain
|
45 |
+
other rights that a licensor has authority to grant. Use of
|
46 |
+
the licensed material may still be restricted for other
|
47 |
+
reasons, including because others have copyright or other
|
48 |
+
rights in the material. A licensor may make special requests,
|
49 |
+
such as asking that all changes be marked or described.
|
50 |
+
Although not required by our licenses, you are encouraged to
|
51 |
+
respect those requests where reasonable. More considerations
|
52 |
+
for the public:
|
53 |
+
wiki.creativecommons.org/Considerations_for_licensees
|
54 |
+
|
55 |
+
=======================================================================
|
56 |
+
|
57 |
+
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
|
58 |
+
Public License
|
59 |
+
|
60 |
+
By exercising the Licensed Rights (defined below), You accept and agree
|
61 |
+
to be bound by the terms and conditions of this Creative Commons
|
62 |
+
Attribution-NonCommercial-ShareAlike 4.0 International Public License
|
63 |
+
("Public License"). To the extent this Public License may be
|
64 |
+
interpreted as a contract, You are granted the Licensed Rights in
|
65 |
+
consideration of Your acceptance of these terms and conditions, and the
|
66 |
+
Licensor grants You such rights in consideration of benefits the
|
67 |
+
Licensor receives from making the Licensed Material available under
|
68 |
+
these terms and conditions.
|
69 |
+
|
70 |
+
|
71 |
+
Section 1 -- Definitions.
|
72 |
+
|
73 |
+
a. Adapted Material means material subject to Copyright and Similar
|
74 |
+
Rights that is derived from or based upon the Licensed Material
|
75 |
+
and in which the Licensed Material is translated, altered,
|
76 |
+
arranged, transformed, or otherwise modified in a manner requiring
|
77 |
+
permission under the Copyright and Similar Rights held by the
|
78 |
+
Licensor. For purposes of this Public License, where the Licensed
|
79 |
+
Material is a musical work, performance, or sound recording,
|
80 |
+
Adapted Material is always produced where the Licensed Material is
|
81 |
+
synched in timed relation with a moving image.
|
82 |
+
|
83 |
+
b. Adapter's License means the license You apply to Your Copyright
|
84 |
+
and Similar Rights in Your contributions to Adapted Material in
|
85 |
+
accordance with the terms and conditions of this Public License.
|
86 |
+
|
87 |
+
c. BY-NC-SA Compatible License means a license listed at
|
88 |
+
creativecommons.org/compatiblelicenses, approved by Creative
|
89 |
+
Commons as essentially the equivalent of this Public License.
|
90 |
+
|
91 |
+
d. Copyright and Similar Rights means copyright and/or similar rights
|
92 |
+
closely related to copyright including, without limitation,
|
93 |
+
performance, broadcast, sound recording, and Sui Generis Database
|
94 |
+
Rights, without regard to how the rights are labeled or
|
95 |
+
categorized. For purposes of this Public License, the rights
|
96 |
+
specified in Section 2(b)(1)-(2) are not Copyright and Similar
|
97 |
+
Rights.
|
98 |
+
|
99 |
+
e. Effective Technological Measures means those measures that, in the
|
100 |
+
absence of proper authority, may not be circumvented under laws
|
101 |
+
fulfilling obligations under Article 11 of the WIPO Copyright
|
102 |
+
Treaty adopted on December 20, 1996, and/or similar international
|
103 |
+
agreements.
|
104 |
+
|
105 |
+
f. Exceptions and Limitations means fair use, fair dealing, and/or
|
106 |
+
any other exception or limitation to Copyright and Similar Rights
|
107 |
+
that applies to Your use of the Licensed Material.
|
108 |
+
|
109 |
+
g. License Elements means the license attributes listed in the name
|
110 |
+
of a Creative Commons Public License. The License Elements of this
|
111 |
+
Public License are Attribution, NonCommercial, and ShareAlike.
|
112 |
+
|
113 |
+
h. Licensed Material means the artistic or literary work, database,
|
114 |
+
or other material to which the Licensor applied this Public
|
115 |
+
License.
|
116 |
+
|
117 |
+
i. Licensed Rights means the rights granted to You subject to the
|
118 |
+
terms and conditions of this Public License, which are limited to
|
119 |
+
all Copyright and Similar Rights that apply to Your use of the
|
120 |
+
Licensed Material and that the Licensor has authority to license.
|
121 |
+
|
122 |
+
j. Licensor means the individual(s) or entity(ies) granting rights
|
123 |
+
under this Public License.
|
124 |
+
|
125 |
+
k. NonCommercial means not primarily intended for or directed towards
|
126 |
+
commercial advantage or monetary compensation. For purposes of
|
127 |
+
this Public License, the exchange of the Licensed Material for
|
128 |
+
other material subject to Copyright and Similar Rights by digital
|
129 |
+
file-sharing or similar means is NonCommercial provided there is
|
130 |
+
no payment of monetary compensation in connection with the
|
131 |
+
exchange.
|
132 |
+
|
133 |
+
l. Share means to provide material to the public by any means or
|
134 |
+
process that requires permission under the Licensed Rights, such
|
135 |
+
as reproduction, public display, public performance, distribution,
|
136 |
+
dissemination, communication, or importation, and to make material
|
137 |
+
available to the public including in ways that members of the
|
138 |
+
public may access the material from a place and at a time
|
139 |
+
individually chosen by them.
|
140 |
+
|
141 |
+
m. Sui Generis Database Rights means rights other than copyright
|
142 |
+
resulting from Directive 96/9/EC of the European Parliament and of
|
143 |
+
the Council of 11 March 1996 on the legal protection of databases,
|
144 |
+
as amended and/or succeeded, as well as other essentially
|
145 |
+
equivalent rights anywhere in the world.
|
146 |
+
|
147 |
+
n. You means the individual or entity exercising the Licensed Rights
|
148 |
+
under this Public License. Your has a corresponding meaning.
|
149 |
+
|
150 |
+
|
151 |
+
Section 2 -- Scope.
|
152 |
+
|
153 |
+
a. License grant.
|
154 |
+
|
155 |
+
1. Subject to the terms and conditions of this Public License,
|
156 |
+
the Licensor hereby grants You a worldwide, royalty-free,
|
157 |
+
non-sublicensable, non-exclusive, irrevocable license to
|
158 |
+
exercise the Licensed Rights in the Licensed Material to:
|
159 |
+
|
160 |
+
a. reproduce and Share the Licensed Material, in whole or
|
161 |
+
in part, for NonCommercial purposes only; and
|
162 |
+
|
163 |
+
b. produce, reproduce, and Share Adapted Material for
|
164 |
+
NonCommercial purposes only.
|
165 |
+
|
166 |
+
2. Exceptions and Limitations. For the avoidance of doubt, where
|
167 |
+
Exceptions and Limitations apply to Your use, this Public
|
168 |
+
License does not apply, and You do not need to comply with
|
169 |
+
its terms and conditions.
|
170 |
+
|
171 |
+
3. Term. The term of this Public License is specified in Section
|
172 |
+
6(a).
|
173 |
+
|
174 |
+
4. Media and formats; technical modifications allowed. The
|
175 |
+
Licensor authorizes You to exercise the Licensed Rights in
|
176 |
+
all media and formats whether now known or hereafter created,
|
177 |
+
and to make technical modifications necessary to do so. The
|
178 |
+
Licensor waives and/or agrees not to assert any right or
|
179 |
+
authority to forbid You from making technical modifications
|
180 |
+
necessary to exercise the Licensed Rights, including
|
181 |
+
technical modifications necessary to circumvent Effective
|
182 |
+
Technological Measures. For purposes of this Public License,
|
183 |
+
simply making modifications authorized by this Section 2(a)
|
184 |
+
(4) never produces Adapted Material.
|
185 |
+
|
186 |
+
5. Downstream recipients.
|
187 |
+
|
188 |
+
a. Offer from the Licensor -- Licensed Material. Every
|
189 |
+
recipient of the Licensed Material automatically
|
190 |
+
receives an offer from the Licensor to exercise the
|
191 |
+
Licensed Rights under the terms and conditions of this
|
192 |
+
Public License.
|
193 |
+
|
194 |
+
b. Additional offer from the Licensor -- Adapted Material.
|
195 |
+
Every recipient of Adapted Material from You
|
196 |
+
automatically receives an offer from the Licensor to
|
197 |
+
exercise the Licensed Rights in the Adapted Material
|
198 |
+
under the conditions of the Adapter's License You apply.
|
199 |
+
|
200 |
+
c. No downstream restrictions. You may not offer or impose
|
201 |
+
any additional or different terms or conditions on, or
|
202 |
+
apply any Effective Technological Measures to, the
|
203 |
+
Licensed Material if doing so restricts exercise of the
|
204 |
+
Licensed Rights by any recipient of the Licensed
|
205 |
+
Material.
|
206 |
+
|
207 |
+
6. No endorsement. Nothing in this Public License constitutes or
|
208 |
+
may be construed as permission to assert or imply that You
|
209 |
+
are, or that Your use of the Licensed Material is, connected
|
210 |
+
with, or sponsored, endorsed, or granted official status by,
|
211 |
+
the Licensor or others designated to receive attribution as
|
212 |
+
provided in Section 3(a)(1)(A)(i).
|
213 |
+
|
214 |
+
b. Other rights.
|
215 |
+
|
216 |
+
1. Moral rights, such as the right of integrity, are not
|
217 |
+
licensed under this Public License, nor are publicity,
|
218 |
+
privacy, and/or other similar personality rights; however, to
|
219 |
+
the extent possible, the Licensor waives and/or agrees not to
|
220 |
+
assert any such rights held by the Licensor to the limited
|
221 |
+
extent necessary to allow You to exercise the Licensed
|
222 |
+
Rights, but not otherwise.
|
223 |
+
|
224 |
+
2. Patent and trademark rights are not licensed under this
|
225 |
+
Public License.
|
226 |
+
|
227 |
+
3. To the extent possible, the Licensor waives any right to
|
228 |
+
collect royalties from You for the exercise of the Licensed
|
229 |
+
Rights, whether directly or through a collecting society
|
230 |
+
under any voluntary or waivable statutory or compulsory
|
231 |
+
licensing scheme. In all other cases the Licensor expressly
|
232 |
+
reserves any right to collect such royalties, including when
|
233 |
+
the Licensed Material is used other than for NonCommercial
|
234 |
+
purposes.
|
235 |
+
|
236 |
+
|
237 |
+
Section 3 -- License Conditions.
|
238 |
+
|
239 |
+
Your exercise of the Licensed Rights is expressly made subject to the
|
240 |
+
following conditions.
|
241 |
+
|
242 |
+
a. Attribution.
|
243 |
+
|
244 |
+
1. If You Share the Licensed Material (including in modified
|
245 |
+
form), You must:
|
246 |
+
|
247 |
+
a. retain the following if it is supplied by the Licensor
|
248 |
+
with the Licensed Material:
|
249 |
+
|
250 |
+
i. identification of the creator(s) of the Licensed
|
251 |
+
Material and any others designated to receive
|
252 |
+
attribution, in any reasonable manner requested by
|
253 |
+
the Licensor (including by pseudonym if
|
254 |
+
designated);
|
255 |
+
|
256 |
+
ii. a copyright notice;
|
257 |
+
|
258 |
+
iii. a notice that refers to this Public License;
|
259 |
+
|
260 |
+
iv. a notice that refers to the disclaimer of
|
261 |
+
warranties;
|
262 |
+
|
263 |
+
v. a URI or hyperlink to the Licensed Material to the
|
264 |
+
extent reasonably practicable;
|
265 |
+
|
266 |
+
b. indicate if You modified the Licensed Material and
|
267 |
+
retain an indication of any previous modifications; and
|
268 |
+
|
269 |
+
c. indicate the Licensed Material is licensed under this
|
270 |
+
Public License, and include the text of, or the URI or
|
271 |
+
hyperlink to, this Public License.
|
272 |
+
|
273 |
+
2. You may satisfy the conditions in Section 3(a)(1) in any
|
274 |
+
reasonable manner based on the medium, means, and context in
|
275 |
+
which You Share the Licensed Material. For example, it may be
|
276 |
+
reasonable to satisfy the conditions by providing a URI or
|
277 |
+
hyperlink to a resource that includes the required
|
278 |
+
information.
|
279 |
+
3. If requested by the Licensor, You must remove any of the
|
280 |
+
information required by Section 3(a)(1)(A) to the extent
|
281 |
+
reasonably practicable.
|
282 |
+
|
283 |
+
b. ShareAlike.
|
284 |
+
|
285 |
+
In addition to the conditions in Section 3(a), if You Share
|
286 |
+
Adapted Material You produce, the following conditions also apply.
|
287 |
+
|
288 |
+
1. The Adapter's License You apply must be a Creative Commons
|
289 |
+
license with the same License Elements, this version or
|
290 |
+
later, or a BY-NC-SA Compatible License.
|
291 |
+
|
292 |
+
2. You must include the text of, or the URI or hyperlink to, the
|
293 |
+
Adapter's License You apply. You may satisfy this condition
|
294 |
+
in any reasonable manner based on the medium, means, and
|
295 |
+
context in which You Share Adapted Material.
|
296 |
+
|
297 |
+
3. You may not offer or impose any additional or different terms
|
298 |
+
or conditions on, or apply any Effective Technological
|
299 |
+
Measures to, Adapted Material that restrict exercise of the
|
300 |
+
rights granted under the Adapter's License You apply.
|
301 |
+
|
302 |
+
|
303 |
+
Section 4 -- Sui Generis Database Rights.
|
304 |
+
|
305 |
+
Where the Licensed Rights include Sui Generis Database Rights that
|
306 |
+
apply to Your use of the Licensed Material:
|
307 |
+
|
308 |
+
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
|
309 |
+
to extract, reuse, reproduce, and Share all or a substantial
|
310 |
+
portion of the contents of the database for NonCommercial purposes
|
311 |
+
only;
|
312 |
+
|
313 |
+
b. if You include all or a substantial portion of the database
|
314 |
+
contents in a database in which You have Sui Generis Database
|
315 |
+
Rights, then the database in which You have Sui Generis Database
|
316 |
+
Rights (but not its individual contents) is Adapted Material,
|
317 |
+
including for purposes of Section 3(b); and
|
318 |
+
|
319 |
+
c. You must comply with the conditions in Section 3(a) if You Share
|
320 |
+
all or a substantial portion of the contents of the database.
|
321 |
+
|
322 |
+
For the avoidance of doubt, this Section 4 supplements and does not
|
323 |
+
replace Your obligations under this Public License where the Licensed
|
324 |
+
Rights include other Copyright and Similar Rights.
|
325 |
+
|
326 |
+
|
327 |
+
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
|
328 |
+
|
329 |
+
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
|
330 |
+
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
|
331 |
+
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
|
332 |
+
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
|
333 |
+
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
|
334 |
+
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
335 |
+
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
|
336 |
+
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
|
337 |
+
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
|
338 |
+
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
|
339 |
+
|
340 |
+
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
|
341 |
+
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
|
342 |
+
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
|
343 |
+
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
|
344 |
+
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
|
345 |
+
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
|
346 |
+
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
|
347 |
+
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
|
348 |
+
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
|
349 |
+
|
350 |
+
c. The disclaimer of warranties and limitation of liability provided
|
351 |
+
above shall be interpreted in a manner that, to the extent
|
352 |
+
possible, most closely approximates an absolute disclaimer and
|
353 |
+
waiver of all liability.
|
354 |
+
|
355 |
+
|
356 |
+
Section 6 -- Term and Termination.
|
357 |
+
|
358 |
+
a. This Public License applies for the term of the Copyright and
|
359 |
+
Similar Rights licensed here. However, if You fail to comply with
|
360 |
+
this Public License, then Your rights under this Public License
|
361 |
+
terminate automatically.
|
362 |
+
|
363 |
+
b. Where Your right to use the Licensed Material has terminated under
|
364 |
+
Section 6(a), it reinstates:
|
365 |
+
|
366 |
+
1. automatically as of the date the violation is cured, provided
|
367 |
+
it is cured within 30 days of Your discovery of the
|
368 |
+
violation; or
|
369 |
+
|
370 |
+
2. upon express reinstatement by the Licensor.
|
371 |
+
|
372 |
+
For the avoidance of doubt, this Section 6(b) does not affect any
|
373 |
+
right the Licensor may have to seek remedies for Your violations
|
374 |
+
of this Public License.
|
375 |
+
|
376 |
+
c. For the avoidance of doubt, the Licensor may also offer the
|
377 |
+
Licensed Material under separate terms or conditions or stop
|
378 |
+
distributing the Licensed Material at any time; however, doing so
|
379 |
+
will not terminate this Public License.
|
380 |
+
|
381 |
+
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
|
382 |
+
License.
|
383 |
+
|
384 |
+
|
385 |
+
Section 7 -- Other Terms and Conditions.
|
386 |
+
|
387 |
+
a. The Licensor shall not be bound by any additional or different
|
388 |
+
terms or conditions communicated by You unless expressly agreed.
|
389 |
+
|
390 |
+
b. Any arrangements, understandings, or agreements regarding the
|
391 |
+
Licensed Material not stated herein are separate from and
|
392 |
+
independent of the terms and conditions of this Public License.
|
393 |
+
|
394 |
+
|
395 |
+
Section 8 -- Interpretation.
|
396 |
+
|
397 |
+
a. For the avoidance of doubt, this Public License does not, and
|
398 |
+
shall not be interpreted to, reduce, limit, restrict, or impose
|
399 |
+
conditions on any use of the Licensed Material that could lawfully
|
400 |
+
be made without permission under this Public License.
|
401 |
+
|
402 |
+
b. To the extent possible, if any provision of this Public License is
|
403 |
+
deemed unenforceable, it shall be automatically reformed to the
|
404 |
+
minimum extent necessary to make it enforceable. If the provision
|
405 |
+
cannot be reformed, it shall be severed from this Public License
|
406 |
+
without affecting the enforceability of the remaining terms and
|
407 |
+
conditions.
|
408 |
+
|
409 |
+
c. No term or condition of this Public License will be waived and no
|
410 |
+
failure to comply consented to unless expressly agreed to by the
|
411 |
+
Licensor.
|
412 |
+
|
413 |
+
d. Nothing in this Public License constitutes or may be interpreted
|
414 |
+
as a limitation upon, or waiver of, any privileges and immunities
|
415 |
+
that apply to the Licensor or You, including from the legal
|
416 |
+
processes of any jurisdiction or authority.
|
417 |
+
|
418 |
+
=======================================================================
|
419 |
+
|
420 |
+
Creative Commons is not a party to its public
|
421 |
+
licenses. Notwithstanding, Creative Commons may elect to apply one of
|
422 |
+
its public licenses to material it publishes and in those instances
|
423 |
+
will be considered the “Licensor.” The text of the Creative Commons
|
424 |
+
public licenses is dedicated to the public domain under the CC0 Public
|
425 |
+
Domain Dedication. Except for the limited purpose of indicating that
|
426 |
+
material is shared under a Creative Commons public license or as
|
427 |
+
otherwise permitted by the Creative Commons policies published at
|
428 |
+
creativecommons.org/policies, Creative Commons does not authorize the
|
429 |
+
use of the trademark "Creative Commons" or any other trademark or logo
|
430 |
+
of Creative Commons without its prior written consent including,
|
431 |
+
without limitation, in connection with any unauthorized modifications
|
432 |
+
to any of its public licenses or any other arrangements,
|
433 |
+
understandings, or agreements concerning use of licensed material. For
|
434 |
+
the avoidance of doubt, this paragraph does not form part of the
|
435 |
+
public licenses.
|
436 |
+
|
437 |
+
Creative Commons may be contacted at creativecommons.org.
|
third_party/cgdetr/data/README.md
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## QVHighlights Dataset
|
2 |
+
|
3 |
+
Our annotation files include 3 splits: `train`, `val` and `test`. Each file is in [JSON Line](https://jsonlines.org/) format, each row of the files can be loaded as a single `dict` in Python. Below is an example of the annotation:
|
4 |
+
|
5 |
+
```
|
6 |
+
{
|
7 |
+
"qid": 8737,
|
8 |
+
"query": "A family is playing basketball together on a green court outside.",
|
9 |
+
"duration": 126,
|
10 |
+
"vid": "bP5KfdFJzC4_660.0_810.0",
|
11 |
+
"relevant_windows": [[0, 16]],
|
12 |
+
"relevant_clip_ids": [0, 1, 2, 3, 4, 5, 6, 7],
|
13 |
+
"saliency_scores": [[4, 1, 1], [4, 1, 1], [4, 2, 1], [4, 3, 2], [4, 3, 2], [4, 3, 3], [4, 3, 3], [4, 3, 2]]
|
14 |
+
}
|
15 |
+
```
|
16 |
+
`qid` is a unique identifier of a `query`. This query corresponds to a video identified by its video id `vid`. The `vid` is formatted as `{youtube_id}_{start_time}_{end_time}`. Use this information, one can retrieve the YouTube video from a url `https://www.youtube.com/embed/{youtube_id}?start={start_time}&end={end_time}&version=3`. For example, the video in this example is `https://www.youtube.com/embed/bP5KfdFJzC4?start=660&end=810&version=3`.
|
17 |
+
`duration` is an integer indicating the duration of this video.
|
18 |
+
`relevant_windows` is the list of windows that localize the moments, each window has two numbers, one indicates the start time of the moment, another one indicates the end time. `relevant_clip_ids` is the list of ids to the segmented 2-second clips that fall into the moments specified by `relevant_windows`, starting from 0.
|
19 |
+
`saliency_scores` contains the saliency scores annotations, each sublist corresponds to a clip in `relevant_clip_ids`. There are 3 elements in each sublist, they are the scores from three different annotators. A score of `4` means `Very Good`, while `0` means `Very Bad`.
|
20 |
+
|
21 |
+
Note that the three fields `relevant_clip_ids`, `relevant_windows` and `saliency_scores` for `test` split is not included. Please refer to [../standalone_eval/README.md](../standalone_eval/README.md) for details on evaluating predictions on `test`.
|
22 |
+
|
23 |
+
In addition to the annotation files, we also provided the subtitle file for our weakly supervised ASR pre-training: [subs_train.jsonl](./subs_train.jsonl). This file is formatted similarly as our annotation files, but without the `saliency_scores` entry. This file is not needed if you do not plan to pretrain models using it.
|
24 |
+
|
third_party/cgdetr/standalone_eval/README.md
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
QVHighlights Evaluation and Codalab Submission
|
2 |
+
==================
|
3 |
+
|
4 |
+
### Task Definition
|
5 |
+
Given a video and a natural language query, our task requires a system to retrieve the most relevant moments in the video, and detect the highlightness of the clips in the video.
|
6 |
+
|
7 |
+
### Evaluation
|
8 |
+
At project root, run
|
9 |
+
```
|
10 |
+
bash standalone_eval/eval_sample.sh
|
11 |
+
```
|
12 |
+
This command will use [eval.py](eval.py) to evaluate the provided prediction file [sample_val_preds.jsonl](sample_val_preds.jsonl),
|
13 |
+
the output will be written into `sample_val_preds_metrics.json`.
|
14 |
+
The content in this generated file should be similar if not the same as [sample_val_preds_metrics_raw.json](sample_val_preds_metrics_raw.json) file.
|
15 |
+
|
16 |
+
### Format
|
17 |
+
|
18 |
+
The prediction file [sample_val_preds.jsonl](sample_val_preds.jsonl) is in [JSON Line](https://jsonlines.org/) format, each row of the files can be loaded as a single `dict` in Python. Below is an example of a single line in the prediction file:
|
19 |
+
```
|
20 |
+
{
|
21 |
+
"qid": 2579,
|
22 |
+
"query": "A girl and her mother cooked while talking with each other on facetime.",
|
23 |
+
"vid": "NUsG9BgSes0_210.0_360.0",
|
24 |
+
"pred_relevant_windows": [
|
25 |
+
[0, 70, 0.9986],
|
26 |
+
[78, 146, 0.4138],
|
27 |
+
[0, 146, 0.0444],
|
28 |
+
...
|
29 |
+
],
|
30 |
+
"pred_saliency_scores": [-0.2452, -0.3779, -0.4746, ...]
|
31 |
+
}
|
32 |
+
|
33 |
+
```
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
| entry | description |
|
38 |
+
| --- | ----|
|
39 |
+
| `qid` | `int`, unique query id |
|
40 |
+
| `query` | `str`, natural language query, not used by the evaluation script |
|
41 |
+
| `vid` | `str`, unique video id |
|
42 |
+
| `pred_relevant_windows` | `list(list)`, moment retrieval predictions. Each sublist contains 3 elements, `[start (seconds), end (seconds), score]`|
|
43 |
+
| `pred_saliency_scores` | `list(float)`, highlight prediction scores. The higher the better. This list should contain a score for each of the 2-second clip in the videos, and is ordered. |
|
44 |
+
|
45 |
+
|
46 |
+
### Codalab Submission
|
47 |
+
To test your model's performance on `test` split,
|
48 |
+
please submit both `val` and `test` predictions to our
|
49 |
+
[Codalab evaluation server](https://codalab.lisn.upsaclay.fr/competitions/6937).
|
50 |
+
The submission file should be a single `.zip ` file (no enclosing folder)
|
51 |
+
that contains the two prediction files
|
52 |
+
`hl_val_submission.jsonl` and `hl_test_submission.jsonl`, each of the `*submission.jsonl` file
|
53 |
+
should be formatted as instructed above.
|
54 |
+
|
third_party/cgdetr/standalone_eval/eval.py
ADDED
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from collections import OrderedDict, defaultdict
|
3 |
+
import json
|
4 |
+
import time
|
5 |
+
import copy
|
6 |
+
import multiprocessing as mp
|
7 |
+
from src.model.cgdetr_main.standalone_eval.utils import compute_average_precision_detection, \
|
8 |
+
compute_temporal_iou_batch_cross, compute_temporal_iou_batch_paired, load_jsonl, get_ap
|
9 |
+
|
10 |
+
|
11 |
+
def compute_average_precision_detection_wrapper(
|
12 |
+
input_triple, tiou_thresholds=np.linspace(0.5, 0.95, 10)):
|
13 |
+
qid, ground_truth, prediction = input_triple
|
14 |
+
scores = compute_average_precision_detection(
|
15 |
+
ground_truth, prediction, tiou_thresholds=tiou_thresholds)
|
16 |
+
return qid, scores
|
17 |
+
|
18 |
+
|
19 |
+
def compute_mr_ap(submission, ground_truth, iou_thds=np.linspace(0.5, 0.95, 10),
|
20 |
+
max_gt_windows=None, max_pred_windows=10, num_workers=8, chunksize=50):
|
21 |
+
iou_thds = [float(f"{e:.2f}") for e in iou_thds]
|
22 |
+
pred_qid2data = defaultdict(list)
|
23 |
+
for d in submission:
|
24 |
+
pred_windows = d["pred_relevant_windows"][:max_pred_windows] \
|
25 |
+
if max_pred_windows is not None else d["pred_relevant_windows"]
|
26 |
+
qid = d["qid"]
|
27 |
+
for w in pred_windows:
|
28 |
+
pred_qid2data[qid].append({
|
29 |
+
"video-id": d["qid"], # in order to use the API
|
30 |
+
"t-start": w[0],
|
31 |
+
"t-end": w[1],
|
32 |
+
"score": w[2]
|
33 |
+
})
|
34 |
+
|
35 |
+
gt_qid2data = defaultdict(list)
|
36 |
+
for d in ground_truth:
|
37 |
+
gt_windows = d["relevant_windows"][:max_gt_windows] \
|
38 |
+
if max_gt_windows is not None else d["relevant_windows"]
|
39 |
+
qid = d["qid"]
|
40 |
+
for w in gt_windows:
|
41 |
+
gt_qid2data[qid].append({
|
42 |
+
"video-id": d["qid"],
|
43 |
+
"t-start": w[0],
|
44 |
+
"t-end": w[1]
|
45 |
+
})
|
46 |
+
qid2ap_list = {}
|
47 |
+
# start_time = time.time()
|
48 |
+
data_triples = [[qid, gt_qid2data[qid], pred_qid2data[qid]] for qid in pred_qid2data]
|
49 |
+
from functools import partial
|
50 |
+
compute_ap_from_triple = partial(
|
51 |
+
compute_average_precision_detection_wrapper, tiou_thresholds=iou_thds)
|
52 |
+
|
53 |
+
if num_workers > 1:
|
54 |
+
with mp.Pool(num_workers) as pool:
|
55 |
+
for qid, scores in pool.imap_unordered(compute_ap_from_triple, data_triples, chunksize=chunksize):
|
56 |
+
qid2ap_list[qid] = scores
|
57 |
+
else:
|
58 |
+
for data_triple in data_triples:
|
59 |
+
qid, scores = compute_ap_from_triple(data_triple)
|
60 |
+
qid2ap_list[qid] = scores
|
61 |
+
|
62 |
+
# print(f"compute_average_precision_detection {time.time() - start_time:.2f} seconds.")
|
63 |
+
ap_array = np.array(list(qid2ap_list.values())) # (#queries, #thd)
|
64 |
+
ap_thds = ap_array.mean(0) # mAP at different IoU thresholds.
|
65 |
+
iou_thd2ap = dict(zip([str(e) for e in iou_thds], ap_thds))
|
66 |
+
iou_thd2ap["average"] = np.mean(ap_thds)
|
67 |
+
# formatting
|
68 |
+
iou_thd2ap = {k: float(f"{100 * v:.2f}") for k, v in iou_thd2ap.items()}
|
69 |
+
return iou_thd2ap
|
70 |
+
|
71 |
+
|
72 |
+
def compute_mr_r1(submission, ground_truth, iou_thds=np.linspace(0.3, 0.95, 14)):
|
73 |
+
"""If a predicted segment has IoU >= iou_thd with one of the 1st GT segment, we define it positive"""
|
74 |
+
iou_thds = [float(f"{e:.2f}") for e in iou_thds]
|
75 |
+
pred_qid2window = {d["qid"]: d["pred_relevant_windows"][0][:2] for d in submission} # :2 rm scores
|
76 |
+
# gt_qid2window = {d["qid"]: d["relevant_windows"][0] for d in ground_truth}
|
77 |
+
gt_qid2window = {}
|
78 |
+
for d in ground_truth:
|
79 |
+
cur_gt_windows = d["relevant_windows"]
|
80 |
+
cur_qid = d["qid"]
|
81 |
+
cur_max_iou_idx = 0
|
82 |
+
if len(cur_gt_windows) > 0: # select the GT window that has the highest IoU
|
83 |
+
cur_ious = compute_temporal_iou_batch_cross(
|
84 |
+
np.array([pred_qid2window[cur_qid]]), np.array(d["relevant_windows"])
|
85 |
+
)[0]
|
86 |
+
cur_max_iou_idx = np.argmax(cur_ious)
|
87 |
+
gt_qid2window[cur_qid] = cur_gt_windows[cur_max_iou_idx]
|
88 |
+
|
89 |
+
qids = list(pred_qid2window.keys())
|
90 |
+
pred_windows = np.array([pred_qid2window[k] for k in qids]).astype(float)
|
91 |
+
gt_windows = np.array([gt_qid2window[k] for k in qids]).astype(float)
|
92 |
+
pred_gt_iou = compute_temporal_iou_batch_paired(pred_windows, gt_windows)
|
93 |
+
iou_thd2recall_at_one = {}
|
94 |
+
miou_at_one = float(f"{np.mean(pred_gt_iou) * 100:.2f}")
|
95 |
+
for thd in iou_thds:
|
96 |
+
iou_thd2recall_at_one[str(thd)] = float(f"{np.mean(pred_gt_iou >= thd) * 100:.2f}")
|
97 |
+
return iou_thd2recall_at_one, miou_at_one
|
98 |
+
|
99 |
+
|
100 |
+
def get_window_len(window):
|
101 |
+
return window[1] - window[0]
|
102 |
+
|
103 |
+
|
104 |
+
def get_data_by_range(submission, ground_truth, len_range):
|
105 |
+
""" keep queries with ground truth window length in the specified length range.
|
106 |
+
Args:
|
107 |
+
submission:
|
108 |
+
ground_truth:
|
109 |
+
len_range: [min_l (int), max_l (int)]. the range is (min_l, max_l], i.e., min_l < l <= max_l
|
110 |
+
"""
|
111 |
+
min_l, max_l = len_range
|
112 |
+
if min_l == 0 and max_l == 150: # min and max l in dataset
|
113 |
+
return submission, ground_truth
|
114 |
+
|
115 |
+
# only keep ground truth with windows in the specified length range
|
116 |
+
# if multiple GT windows exists, we only keep the ones in the range
|
117 |
+
ground_truth_in_range = []
|
118 |
+
gt_qids_in_range = set()
|
119 |
+
for d in ground_truth:
|
120 |
+
rel_windows_in_range = [
|
121 |
+
w for w in d["relevant_windows"] if min_l < get_window_len(w) <= max_l]
|
122 |
+
if len(rel_windows_in_range) > 0:
|
123 |
+
d = copy.deepcopy(d)
|
124 |
+
d["relevant_windows"] = rel_windows_in_range
|
125 |
+
ground_truth_in_range.append(d)
|
126 |
+
gt_qids_in_range.add(d["qid"])
|
127 |
+
|
128 |
+
# keep only submissions for ground_truth_in_range
|
129 |
+
submission_in_range = []
|
130 |
+
for d in submission:
|
131 |
+
if d["qid"] in gt_qids_in_range:
|
132 |
+
submission_in_range.append(copy.deepcopy(d))
|
133 |
+
|
134 |
+
return submission_in_range, ground_truth_in_range
|
135 |
+
|
136 |
+
|
137 |
+
def eval_moment_retrieval(submission, ground_truth, verbose=True):
|
138 |
+
length_ranges = [[0, 10], [10, 30], [30, 150], [0, 150], ] #
|
139 |
+
range_names = ["short", "middle", "long", "full"]
|
140 |
+
|
141 |
+
ret_metrics = {}
|
142 |
+
for l_range, name in zip(length_ranges, range_names):
|
143 |
+
if verbose:
|
144 |
+
start_time = time.time()
|
145 |
+
_submission, _ground_truth = get_data_by_range(submission, ground_truth, l_range)
|
146 |
+
print(f"{name}: {l_range}, {len(_ground_truth)}/{len(ground_truth)}="
|
147 |
+
f"{100*len(_ground_truth)/len(ground_truth):.2f} examples.")
|
148 |
+
if len(_ground_truth) == 0:
|
149 |
+
# ret_metrics[name] = {"MR-mAP": 0., "MR-R1": 0.}
|
150 |
+
dummy_dict = {}
|
151 |
+
for k in np.linspace(0.5, 0.95, 19):
|
152 |
+
dummy_dict[k] = 0.
|
153 |
+
dummy_dict['average'] = 0.
|
154 |
+
ret_metrics[name] = {"MR-mAP": dummy_dict, "MR-R1": dummy_dict}
|
155 |
+
else:
|
156 |
+
iou_thd2average_precision = compute_mr_ap(_submission, _ground_truth, num_workers=8, chunksize=50)
|
157 |
+
iou_thd2recall_at_one, miou_at_one = compute_mr_r1(_submission, _ground_truth)
|
158 |
+
ret_metrics[name] = {"MR-mIoU": miou_at_one,
|
159 |
+
"MR-mAP": iou_thd2average_precision,
|
160 |
+
"MR-R1": iou_thd2recall_at_one}
|
161 |
+
|
162 |
+
# iou_thd2average_precision = compute_mr_ap(_submission, _ground_truth, num_workers=8, chunksize=50)
|
163 |
+
# iou_thd2recall_at_one = compute_mr_r1(_submission, _ground_truth)
|
164 |
+
# ret_metrics[name] = {"MR-mAP": iou_thd2average_precision, "MR-R1": iou_thd2recall_at_one}
|
165 |
+
if verbose:
|
166 |
+
print(f"[eval_moment_retrieval] [{name}] {time.time() - start_time:.2f} seconds")
|
167 |
+
return ret_metrics
|
168 |
+
|
169 |
+
|
170 |
+
def compute_hl_hit1(qid2preds, qid2gt_scores_binary):
|
171 |
+
qid2max_scored_clip_idx = {k: np.argmax(v["pred_saliency_scores"]) for k, v in qid2preds.items()}
|
172 |
+
hit_scores = np.zeros((len(qid2preds), 3))
|
173 |
+
qids = list(qid2preds.keys())
|
174 |
+
for idx, qid in enumerate(qids):
|
175 |
+
pred_clip_idx = qid2max_scored_clip_idx[qid]
|
176 |
+
gt_scores_binary = qid2gt_scores_binary[qid] # (#clips, 3)
|
177 |
+
if pred_clip_idx < len(gt_scores_binary):
|
178 |
+
hit_scores[idx] = gt_scores_binary[pred_clip_idx]
|
179 |
+
# aggregate scores from 3 separate annotations (3 workers) by taking the max.
|
180 |
+
# then average scores from all queries.
|
181 |
+
hit_at_one = float(f"{100 * np.mean(np.max(hit_scores, 1)):.2f}")
|
182 |
+
return hit_at_one
|
183 |
+
|
184 |
+
|
185 |
+
def compute_hl_ap(qid2preds, qid2gt_scores_binary, num_workers=8, chunksize=50):
|
186 |
+
qid2pred_scores = {k: v["pred_saliency_scores"] for k, v in qid2preds.items()}
|
187 |
+
ap_scores = np.zeros((len(qid2preds), 3)) # (#preds, 3)
|
188 |
+
qids = list(qid2preds.keys())
|
189 |
+
input_tuples = []
|
190 |
+
for idx, qid in enumerate(qids):
|
191 |
+
for w_idx in range(3): # annotation score idx
|
192 |
+
y_true = qid2gt_scores_binary[qid][:, w_idx]
|
193 |
+
y_predict = np.array(qid2pred_scores[qid])
|
194 |
+
input_tuples.append((idx, w_idx, y_true, y_predict))
|
195 |
+
|
196 |
+
if num_workers > 1:
|
197 |
+
with mp.Pool(num_workers) as pool:
|
198 |
+
for idx, w_idx, score in pool.imap_unordered(
|
199 |
+
compute_ap_from_tuple, input_tuples, chunksize=chunksize):
|
200 |
+
ap_scores[idx, w_idx] = score
|
201 |
+
else:
|
202 |
+
for input_tuple in input_tuples:
|
203 |
+
idx, w_idx, score = compute_ap_from_tuple(input_tuple)
|
204 |
+
ap_scores[idx, w_idx] = score
|
205 |
+
|
206 |
+
# it's the same if we first average across different annotations, then average across queries
|
207 |
+
# since all queries have the same #annotations.
|
208 |
+
mean_ap = float(f"{100 * np.mean(ap_scores):.2f}")
|
209 |
+
return mean_ap
|
210 |
+
|
211 |
+
|
212 |
+
def compute_ap_from_tuple(input_tuple):
|
213 |
+
idx, w_idx, y_true, y_predict = input_tuple
|
214 |
+
if len(y_true) < len(y_predict):
|
215 |
+
# print(f"len(y_true) < len(y_predict) {len(y_true), len(y_predict)}")
|
216 |
+
y_predict = y_predict[:len(y_true)]
|
217 |
+
elif len(y_true) > len(y_predict):
|
218 |
+
# print(f"len(y_true) > len(y_predict) {len(y_true), len(y_predict)}")
|
219 |
+
_y_predict = np.zeros(len(y_true))
|
220 |
+
_y_predict[:len(y_predict)] = y_predict
|
221 |
+
y_predict = _y_predict
|
222 |
+
|
223 |
+
score = get_ap(y_true, y_predict)
|
224 |
+
return idx, w_idx, score
|
225 |
+
|
226 |
+
|
227 |
+
def mk_gt_scores(gt_data, clip_length=2):
|
228 |
+
"""gt_data, dict, """
|
229 |
+
num_clips = int(gt_data["duration"] / clip_length)
|
230 |
+
saliency_scores_full_video = np.zeros((num_clips, 3))
|
231 |
+
relevant_clip_ids = np.array(gt_data["relevant_clip_ids"]) # (#relevant_clip_ids, )
|
232 |
+
saliency_scores_relevant_clips = np.array(gt_data["saliency_scores"]) # (#relevant_clip_ids, 3)
|
233 |
+
saliency_scores_full_video[relevant_clip_ids] = saliency_scores_relevant_clips
|
234 |
+
return saliency_scores_full_video # (#clips_in_video, 3) the scores are in range [0, 4]
|
235 |
+
|
236 |
+
|
237 |
+
def eval_highlight(submission, ground_truth, verbose=True):
|
238 |
+
"""
|
239 |
+
Args:
|
240 |
+
submission:
|
241 |
+
ground_truth:
|
242 |
+
verbose:
|
243 |
+
"""
|
244 |
+
qid2preds = {d["qid"]: d for d in submission}
|
245 |
+
qid2gt_scores_full_range = {d["qid"]: mk_gt_scores(d) for d in ground_truth} # scores in range [0, 4]
|
246 |
+
# gt_saliency_score_min: int, in [0, 1, 2, 3, 4]. The minimum score for a positive clip.
|
247 |
+
gt_saliency_score_min_list = [2, 3, 4]
|
248 |
+
saliency_score_names = ["Fair", "Good", "VeryGood"]
|
249 |
+
highlight_det_metrics = {}
|
250 |
+
for gt_saliency_score_min, score_name in zip(gt_saliency_score_min_list, saliency_score_names):
|
251 |
+
start_time = time.time()
|
252 |
+
qid2gt_scores_binary = {
|
253 |
+
k: (v >= gt_saliency_score_min).astype(float)
|
254 |
+
for k, v in qid2gt_scores_full_range.items()} # scores in [0, 1]
|
255 |
+
hit_at_one = compute_hl_hit1(qid2preds, qid2gt_scores_binary)
|
256 |
+
mean_ap = compute_hl_ap(qid2preds, qid2gt_scores_binary)
|
257 |
+
highlight_det_metrics[f"HL-min-{score_name}"] = {"HL-mAP": mean_ap, "HL-Hit1": hit_at_one}
|
258 |
+
if verbose:
|
259 |
+
print(f"Calculating highlight scores with min score {gt_saliency_score_min} ({score_name})")
|
260 |
+
print(f"Time cost {time.time() - start_time:.2f} seconds")
|
261 |
+
return highlight_det_metrics
|
262 |
+
|
263 |
+
|
264 |
+
def eval_submission(submission, ground_truth, verbose=True, match_number=False, hl=False):
|
265 |
+
"""
|
266 |
+
Args:
|
267 |
+
submission: list(dict), each dict is {
|
268 |
+
qid: str,
|
269 |
+
query: str,
|
270 |
+
vid: str,
|
271 |
+
pred_relevant_windows: list([st, ed]),
|
272 |
+
pred_saliency_scores: list(float), len == #clips in video.
|
273 |
+
i.e., each clip in the video will have a saliency score.
|
274 |
+
}
|
275 |
+
ground_truth: list(dict), each dict is {
|
276 |
+
"qid": 7803,
|
277 |
+
"query": "Man in gray top walks from outside to inside.",
|
278 |
+
"duration": 150,
|
279 |
+
"vid": "RoripwjYFp8_360.0_510.0",
|
280 |
+
"relevant_clip_ids": [13, 14, 15, 16, 17]
|
281 |
+
"saliency_scores": [[4, 4, 2], [3, 4, 2], [2, 2, 3], [2, 2, 2], [0, 1, 3]]
|
282 |
+
each sublist corresponds to one clip in relevant_clip_ids.
|
283 |
+
The 3 elements in the sublist are scores from 3 different workers. The
|
284 |
+
scores are in [0, 1, 2, 3, 4], meaning [Very Bad, ..., Good, Very Good]
|
285 |
+
}
|
286 |
+
verbose:
|
287 |
+
match_number:
|
288 |
+
|
289 |
+
Returns:
|
290 |
+
|
291 |
+
"""
|
292 |
+
pred_qids = set([e["qid"] for e in submission])
|
293 |
+
gt_qids = set([e["qid"] for e in ground_truth])
|
294 |
+
# import pdb; pdb.set_trace()
|
295 |
+
if match_number:
|
296 |
+
assert pred_qids == gt_qids, \
|
297 |
+
f"qids in ground_truth and submission must match. " \
|
298 |
+
f"use `match_number=False` if you wish to disable this check"
|
299 |
+
else: # only leave the items that exists in both submission and ground_truth
|
300 |
+
shared_qids = pred_qids.intersection(gt_qids)
|
301 |
+
submission = [e for e in submission if e["qid"] in shared_qids]
|
302 |
+
ground_truth = [e for e in ground_truth if e["qid"] in shared_qids]
|
303 |
+
|
304 |
+
eval_metrics = {}
|
305 |
+
eval_metrics_brief = OrderedDict()
|
306 |
+
if "pred_relevant_windows" in submission[0]:
|
307 |
+
moment_ret_scores = eval_moment_retrieval(submission, ground_truth, verbose=verbose)
|
308 |
+
eval_metrics.update(moment_ret_scores)
|
309 |
+
moment_ret_scores_brief = {
|
310 |
+
"MR-full-mAP": moment_ret_scores["full"]["MR-mAP"]["average"],
|
311 |
+
"[email protected]": moment_ret_scores["full"]["MR-mAP"]["0.5"],
|
312 |
+
"[email protected]": moment_ret_scores["full"]["MR-mAP"]["0.75"],
|
313 |
+
"MR-short-mAP": moment_ret_scores["short"]["MR-mAP"]["average"],
|
314 |
+
"MR-middle-mAP": moment_ret_scores["middle"]["MR-mAP"]["average"],
|
315 |
+
"MR-long-mAP": moment_ret_scores["long"]["MR-mAP"]["average"],
|
316 |
+
"MR-full-mIoU": moment_ret_scores["full"]["MR-mIoU"],
|
317 |
+
"[email protected]": moment_ret_scores["full"]["MR-R1"]["0.3"],
|
318 |
+
"[email protected]": moment_ret_scores["full"]["MR-R1"]["0.5"],
|
319 |
+
"[email protected]": moment_ret_scores["full"]["MR-R1"]["0.7"],
|
320 |
+
}
|
321 |
+
eval_metrics_brief.update(
|
322 |
+
sorted([(k, v) for k, v in moment_ret_scores_brief.items()], key=lambda x: x[0]))
|
323 |
+
|
324 |
+
if "pred_saliency_scores" in submission[0] and hl:
|
325 |
+
highlight_det_scores = eval_highlight(
|
326 |
+
submission, ground_truth, verbose=verbose)
|
327 |
+
eval_metrics.update(highlight_det_scores)
|
328 |
+
highlight_det_scores_brief = dict([
|
329 |
+
(f"{k}-{sub_k.split('-')[1]}", v[sub_k])
|
330 |
+
for k, v in highlight_det_scores.items() for sub_k in v])
|
331 |
+
eval_metrics_brief.update(highlight_det_scores_brief)
|
332 |
+
|
333 |
+
# sort by keys
|
334 |
+
final_eval_metrics = OrderedDict()
|
335 |
+
final_eval_metrics["brief"] = eval_metrics_brief
|
336 |
+
final_eval_metrics.update(sorted([(k, v) for k, v in eval_metrics.items()], key=lambda x: x[0]))
|
337 |
+
return final_eval_metrics
|
338 |
+
|
339 |
+
|
340 |
+
def eval_main():
|
341 |
+
import argparse
|
342 |
+
parser = argparse.ArgumentParser(description="Moments and Highlights Evaluation Script")
|
343 |
+
parser.add_argument("--submission_path", type=str, help="path to generated prediction file")
|
344 |
+
parser.add_argument("--gt_path", type=str, help="path to GT file")
|
345 |
+
parser.add_argument("--save_path", type=str, help="path to save the results")
|
346 |
+
parser.add_argument("--not_verbose", action="store_true")
|
347 |
+
args = parser.parse_args()
|
348 |
+
|
349 |
+
verbose = not args.not_verbose
|
350 |
+
submission = load_jsonl(args.submission_path)
|
351 |
+
gt = load_jsonl(args.gt_path)
|
352 |
+
results = eval_submission(submission, gt, verbose=verbose)
|
353 |
+
if verbose:
|
354 |
+
print(json.dumps(results, indent=4))
|
355 |
+
|
356 |
+
with open(args.save_path, "w") as f:
|
357 |
+
f.write(json.dumps(results, indent=4))
|
358 |
+
|
359 |
+
|
360 |
+
if __name__ == '__main__':
|
361 |
+
eval_main()
|
third_party/cgdetr/standalone_eval/eval_sample.sh
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
# Usage: bash standalone_eval/eval_sample.sh
|
3 |
+
submission_path=standalone_eval/sample_val_preds.jsonl
|
4 |
+
gt_path=data/highlight_val_release.jsonl
|
5 |
+
save_path=standalone_eval/sample_val_preds_metrics.json
|
6 |
+
|
7 |
+
PYTHONPATH=$PYTHONPATH:. python standalone_eval/eval.py \
|
8 |
+
--submission_path ${submission_path} \
|
9 |
+
--gt_path ${gt_path} \
|
10 |
+
--save_path ${save_path}
|
third_party/cgdetr/standalone_eval/utils.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copied from MMAction2
|
3 |
+
https://github.com/open-mmlab/mmaction2/blob/master/mmaction/core/evaluation/eval_detection.py
|
4 |
+
"""
|
5 |
+
import json
|
6 |
+
import numpy as np
|
7 |
+
from sklearn.metrics import precision_recall_curve
|
8 |
+
|
9 |
+
|
10 |
+
def load_jsonl(filename):
|
11 |
+
with open(filename, "r") as f:
|
12 |
+
return [json.loads(l.strip("\n")) for l in f.readlines()]
|
13 |
+
|
14 |
+
|
15 |
+
def compute_temporal_iou_batch_paired(pred_windows, gt_windows):
|
16 |
+
""" compute intersection-over-union along temporal axis for each pair of windows in pred_windows and gt_windows.
|
17 |
+
Args:
|
18 |
+
pred_windows: np.ndarray, (N, 2), [st (float), ed (float)] * N
|
19 |
+
gt_windows: np.ndarray, (N, 2), [st (float), ed (float)] * N
|
20 |
+
Returns:
|
21 |
+
iou (float): np.ndarray, (N, )
|
22 |
+
|
23 |
+
References:
|
24 |
+
for np.divide with zeros, see https://stackoverflow.com/a/37977222
|
25 |
+
"""
|
26 |
+
intersection = np.maximum(
|
27 |
+
0, np.minimum(pred_windows[:, 1], gt_windows[:, 1]) - np.maximum(pred_windows[:, 0], gt_windows[:, 0])
|
28 |
+
)
|
29 |
+
union = np.maximum(pred_windows[:, 1], gt_windows[:, 1]) \
|
30 |
+
- np.minimum(pred_windows[:, 0], gt_windows[:, 0]) # not the correct union though
|
31 |
+
return np.divide(intersection, union, out=np.zeros_like(intersection), where=union != 0)
|
32 |
+
|
33 |
+
|
34 |
+
def compute_temporal_iou_batch_cross(spans1, spans2):
|
35 |
+
"""
|
36 |
+
Args:
|
37 |
+
spans1: (N, 2) np.ndarray, each row defines a span [st, ed]
|
38 |
+
spans2: (M, 2) np.ndarray, ...
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
iou: (N, M) np.ndarray
|
42 |
+
union: (N, M) np.ndarray
|
43 |
+
>>> spans1 = np.array([[0, 0.2, 0.9], [0.5, 1.0, 0.2]])
|
44 |
+
>>> spans2 = np.array([[0, 0.3], [0., 1.0]])
|
45 |
+
>>> compute_temporal_iou_batch_cross(spans1, spans2)
|
46 |
+
(tensor([[0.6667, 0.2000],
|
47 |
+
[0.0000, 0.5000]]),
|
48 |
+
tensor([[0.3000, 1.0000],
|
49 |
+
[0.8000, 1.0000]]))
|
50 |
+
"""
|
51 |
+
areas1 = spans1[:, 1] - spans1[:, 0] # (N, )
|
52 |
+
areas2 = spans2[:, 1] - spans2[:, 0] # (M, )
|
53 |
+
|
54 |
+
left = np.maximum(spans1[:, None, 0], spans2[None, :, 0]) # (N, M)
|
55 |
+
right = np.minimum(spans1[:, None, 1], spans2[None, :, 1]) # (N, M)
|
56 |
+
|
57 |
+
inter = np.clip(right - left, 0, None) # (N, M)
|
58 |
+
union = areas1[:, None] + areas2[None, :] - inter # (N, M)
|
59 |
+
|
60 |
+
iou = inter / union
|
61 |
+
return iou, union
|
62 |
+
|
63 |
+
|
64 |
+
def interpolated_precision_recall(precision, recall):
|
65 |
+
"""Interpolated AP - VOCdevkit from VOC 2011.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
precision (np.ndarray): The precision of different thresholds.
|
69 |
+
recall (np.ndarray): The recall of different thresholds.
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
float: Average precision score.
|
73 |
+
"""
|
74 |
+
mprecision = np.hstack([[0], precision, [0]])
|
75 |
+
mrecall = np.hstack([[0], recall, [1]])
|
76 |
+
for i in range(len(mprecision) - 1)[::-1]:
|
77 |
+
mprecision[i] = max(mprecision[i], mprecision[i + 1])
|
78 |
+
idx = np.where(mrecall[1::] != mrecall[0:-1])[0] + 1
|
79 |
+
ap = np.sum((mrecall[idx] - mrecall[idx - 1]) * mprecision[idx])
|
80 |
+
return ap
|
81 |
+
|
82 |
+
|
83 |
+
def compute_average_precision_detection(ground_truth,
|
84 |
+
prediction,
|
85 |
+
tiou_thresholds=np.linspace(
|
86 |
+
0.5, 0.95, 10)):
|
87 |
+
"""Compute average precision (detection task) between ground truth and
|
88 |
+
predictions data frames. If multiple predictions occurs for the same
|
89 |
+
predicted segment, only the one with highest score is matches as true
|
90 |
+
positive. This code is greatly inspired by Pascal VOC devkit.
|
91 |
+
|
92 |
+
Args:
|
93 |
+
ground_truth (list[dict]): List containing the ground truth instances
|
94 |
+
(dictionaries). Required keys are 'video-id', 't-start' and
|
95 |
+
't-end'.
|
96 |
+
prediction (list[dict]): List containing the prediction instances
|
97 |
+
(dictionaries). Required keys are: 'video-id', 't-start', 't-end'
|
98 |
+
and 'score'.
|
99 |
+
tiou_thresholds (np.ndarray): A 1darray indicates the temporal
|
100 |
+
intersection over union threshold, which is optional.
|
101 |
+
Default: ``np.linspace(0.5, 0.95, 10)``.
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
Float: ap, Average precision score.
|
105 |
+
"""
|
106 |
+
num_thresholds = len(tiou_thresholds)
|
107 |
+
num_gts = len(ground_truth)
|
108 |
+
num_preds = len(prediction)
|
109 |
+
ap = np.zeros(num_thresholds)
|
110 |
+
if len(prediction) == 0:
|
111 |
+
return ap
|
112 |
+
|
113 |
+
num_positive = float(num_gts)
|
114 |
+
lock_gt = np.ones((num_thresholds, num_gts)) * -1
|
115 |
+
# Sort predictions by decreasing score order.
|
116 |
+
prediction.sort(key=lambda x: -x['score'])
|
117 |
+
# Initialize true positive and false positive vectors.
|
118 |
+
tp = np.zeros((num_thresholds, num_preds))
|
119 |
+
fp = np.zeros((num_thresholds, num_preds))
|
120 |
+
|
121 |
+
# Adaptation to query faster
|
122 |
+
ground_truth_by_videoid = {}
|
123 |
+
for i, item in enumerate(ground_truth):
|
124 |
+
item['index'] = i
|
125 |
+
ground_truth_by_videoid.setdefault(item['video-id'], []).append(item)
|
126 |
+
|
127 |
+
# Assigning true positive to truly grount truth instances.
|
128 |
+
for idx, pred in enumerate(prediction):
|
129 |
+
if pred['video-id'] in ground_truth_by_videoid:
|
130 |
+
gts = ground_truth_by_videoid[pred['video-id']]
|
131 |
+
else:
|
132 |
+
fp[:, idx] = 1
|
133 |
+
continue
|
134 |
+
|
135 |
+
_pred = np.array([[pred['t-start'], pred['t-end']], ])
|
136 |
+
_gt = np.array([[gt['t-start'], gt['t-end']] for gt in gts])
|
137 |
+
tiou_arr = compute_temporal_iou_batch_cross(_pred, _gt)[0]
|
138 |
+
|
139 |
+
tiou_arr = tiou_arr.reshape(-1)
|
140 |
+
# We would like to retrieve the predictions with highest tiou score.
|
141 |
+
tiou_sorted_idx = tiou_arr.argsort()[::-1]
|
142 |
+
for t_idx, tiou_threshold in enumerate(tiou_thresholds):
|
143 |
+
for j_idx in tiou_sorted_idx:
|
144 |
+
if tiou_arr[j_idx] < tiou_threshold:
|
145 |
+
fp[t_idx, idx] = 1
|
146 |
+
break
|
147 |
+
if lock_gt[t_idx, gts[j_idx]['index']] >= 0:
|
148 |
+
continue
|
149 |
+
# Assign as true positive after the filters above.
|
150 |
+
tp[t_idx, idx] = 1
|
151 |
+
lock_gt[t_idx, gts[j_idx]['index']] = idx
|
152 |
+
break
|
153 |
+
|
154 |
+
if fp[t_idx, idx] == 0 and tp[t_idx, idx] == 0:
|
155 |
+
fp[t_idx, idx] = 1
|
156 |
+
|
157 |
+
tp_cumsum = np.cumsum(tp, axis=1).astype(float)
|
158 |
+
fp_cumsum = np.cumsum(fp, axis=1).astype(float)
|
159 |
+
recall_cumsum = tp_cumsum / num_positive
|
160 |
+
|
161 |
+
precision_cumsum = tp_cumsum / (tp_cumsum + fp_cumsum)
|
162 |
+
|
163 |
+
for t_idx in range(len(tiou_thresholds)):
|
164 |
+
ap[t_idx] = interpolated_precision_recall(precision_cumsum[t_idx, :],
|
165 |
+
recall_cumsum[t_idx, :])
|
166 |
+
return ap
|
167 |
+
|
168 |
+
|
169 |
+
def get_ap(y_true, y_predict, interpolate=True, point_11=False):
|
170 |
+
"""
|
171 |
+
Average precision in different formats: (non-) interpolated and/or 11-point approximated
|
172 |
+
point_11=True and interpolate=True corresponds to the 11-point interpolated AP used in
|
173 |
+
the PASCAL VOC challenge up to the 2008 edition and has been verfied against the vlfeat implementation
|
174 |
+
The exact average precision (interpolate=False, point_11=False) corresponds to the one of vl_feat
|
175 |
+
|
176 |
+
:param y_true: list/ numpy vector of true labels in {0,1} for each element
|
177 |
+
:param y_predict: predicted score for each element
|
178 |
+
:param interpolate: Use interpolation?
|
179 |
+
:param point_11: Use 11-point approximation to average precision?
|
180 |
+
:return: average precision
|
181 |
+
|
182 |
+
ref: https://github.com/gyglim/video2gif_dataset/blob/master/v2g_evaluation/__init__.py
|
183 |
+
|
184 |
+
"""
|
185 |
+
# Check inputs
|
186 |
+
assert len(y_true) == len(y_predict), "Prediction and ground truth need to be of the same length"
|
187 |
+
if len(set(y_true)) == 1:
|
188 |
+
if y_true[0] == 0:
|
189 |
+
return 0 # True labels are all zeros
|
190 |
+
# raise ValueError('True labels cannot all be zero')
|
191 |
+
else:
|
192 |
+
return 1
|
193 |
+
else:
|
194 |
+
assert sorted(set(y_true)) == [0, 1], "Ground truth can only contain elements {0,1}"
|
195 |
+
|
196 |
+
# Compute precision and recall
|
197 |
+
precision, recall, _ = precision_recall_curve(y_true, y_predict)
|
198 |
+
recall = recall.astype(np.float32)
|
199 |
+
|
200 |
+
if interpolate: # Compute the interpolated precision
|
201 |
+
for i in range(1, len(precision)):
|
202 |
+
precision[i] = max(precision[i - 1], precision[i])
|
203 |
+
|
204 |
+
if point_11: # Compute the 11-point approximated AP
|
205 |
+
precision_11 = [precision[np.where(recall >= t)[0][-1]] for t in np.arange(0, 1.01, 0.1)]
|
206 |
+
return np.mean(precision_11)
|
207 |
+
else: # Compute the AP using precision at every additionally recalled sample
|
208 |
+
indices = np.where(np.diff(recall))
|
209 |
+
return np.mean(precision[indices])
|
third_party/cgdetr/utils/basic_utils.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import zipfile
|
4 |
+
import numpy as np
|
5 |
+
import pickle
|
6 |
+
from collections import OrderedDict, Counter
|
7 |
+
import pandas as pd
|
8 |
+
|
9 |
+
|
10 |
+
def load_pickle(filename):
|
11 |
+
with open(filename, "rb") as f:
|
12 |
+
return pickle.load(f)
|
13 |
+
|
14 |
+
|
15 |
+
def save_pickle(data, filename):
|
16 |
+
with open(filename, "wb") as f:
|
17 |
+
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
|
18 |
+
|
19 |
+
|
20 |
+
def load_json(filename):
|
21 |
+
with open(filename, "r") as f:
|
22 |
+
return json.load(f)
|
23 |
+
|
24 |
+
|
25 |
+
def save_json(data, filename, save_pretty=False, sort_keys=False):
|
26 |
+
with open(filename, "w") as f:
|
27 |
+
if save_pretty:
|
28 |
+
f.write(json.dumps(data, indent=4, sort_keys=sort_keys))
|
29 |
+
else:
|
30 |
+
json.dump(data, f)
|
31 |
+
|
32 |
+
|
33 |
+
def load_jsonl(filename):
|
34 |
+
with open(filename, "r") as f:
|
35 |
+
return [json.loads(l.strip("\n")) for l in f.readlines()]
|
36 |
+
|
37 |
+
|
38 |
+
def save_jsonl(data, filename):
|
39 |
+
"""data is a list"""
|
40 |
+
with open(filename, "w") as f:
|
41 |
+
f.write("\n".join([json.dumps(e) for e in data]))
|
42 |
+
|
43 |
+
|
44 |
+
def save_lines(list_of_str, filepath):
|
45 |
+
with open(filepath, "w") as f:
|
46 |
+
f.write("\n".join(list_of_str))
|
47 |
+
|
48 |
+
|
49 |
+
def read_lines(filepath):
|
50 |
+
with open(filepath, "r") as f:
|
51 |
+
return [e.strip("\n") for e in f.readlines()]
|
52 |
+
|
53 |
+
|
54 |
+
def mkdirp(p):
|
55 |
+
if not os.path.exists(p):
|
56 |
+
os.makedirs(p)
|
57 |
+
|
58 |
+
|
59 |
+
def flat_list_of_lists(l):
|
60 |
+
"""flatten a list of lists [[1,2], [3,4]] to [1,2,3,4]"""
|
61 |
+
return [item for sublist in l for item in sublist]
|
62 |
+
|
63 |
+
|
64 |
+
def convert_to_seconds(hms_time):
|
65 |
+
""" convert '00:01:12' to 72 seconds.
|
66 |
+
:hms_time (str): time in comma separated string, e.g. '00:01:12'
|
67 |
+
:return (int): time in seconds, e.g. 72
|
68 |
+
"""
|
69 |
+
times = [float(t) for t in hms_time.split(":")]
|
70 |
+
return times[0] * 3600 + times[1] * 60 + times[2]
|
71 |
+
|
72 |
+
|
73 |
+
def get_video_name_from_url(url):
|
74 |
+
return url.split("/")[-1][:-4]
|
75 |
+
|
76 |
+
|
77 |
+
def merge_dicts(list_dicts):
|
78 |
+
merged_dict = list_dicts[0].copy()
|
79 |
+
for i in range(1, len(list_dicts)):
|
80 |
+
merged_dict.update(list_dicts[i])
|
81 |
+
return merged_dict
|
82 |
+
|
83 |
+
|
84 |
+
def l2_normalize_np_array(np_array, eps=1e-5):
|
85 |
+
"""np_array: np.ndarray, (*, D), where the last dim will be normalized"""
|
86 |
+
return np_array / (np.linalg.norm(np_array, axis=-1, keepdims=True) + eps)
|
87 |
+
|
88 |
+
|
89 |
+
def make_zipfile(src_dir, save_path, enclosing_dir="", exclude_dirs=None, exclude_extensions=None,
|
90 |
+
exclude_dirs_substring=None):
|
91 |
+
"""make a zip file of root_dir, save it to save_path.
|
92 |
+
exclude_paths will be excluded if it is a subdir of root_dir.
|
93 |
+
An enclosing_dir is added is specified.
|
94 |
+
"""
|
95 |
+
abs_src = os.path.abspath(src_dir)
|
96 |
+
with zipfile.ZipFile(save_path, "w") as zf:
|
97 |
+
for dirname, subdirs, files in os.walk(src_dir):
|
98 |
+
if exclude_dirs is not None:
|
99 |
+
for e_p in exclude_dirs:
|
100 |
+
if e_p in subdirs:
|
101 |
+
subdirs.remove(e_p)
|
102 |
+
if exclude_dirs_substring is not None:
|
103 |
+
to_rm = []
|
104 |
+
for d in subdirs:
|
105 |
+
if exclude_dirs_substring in d:
|
106 |
+
to_rm.append(d)
|
107 |
+
for e in to_rm:
|
108 |
+
subdirs.remove(e)
|
109 |
+
arcname = os.path.join(enclosing_dir, dirname[len(abs_src) + 1:])
|
110 |
+
zf.write(dirname, arcname)
|
111 |
+
for filename in files:
|
112 |
+
if exclude_extensions is not None:
|
113 |
+
if os.path.splitext(filename)[1] in exclude_extensions:
|
114 |
+
continue # do not zip it
|
115 |
+
absname = os.path.join(dirname, filename)
|
116 |
+
arcname = os.path.join(enclosing_dir, absname[len(abs_src) + 1:])
|
117 |
+
zf.write(absname, arcname)
|
118 |
+
|
119 |
+
|
120 |
+
class AverageMeter(object):
|
121 |
+
"""Computes and stores the average and current/max/min value"""
|
122 |
+
def __init__(self):
|
123 |
+
self.val = 0
|
124 |
+
self.avg = 0
|
125 |
+
self.sum = 0
|
126 |
+
self.count = 0
|
127 |
+
self.max = -1e10
|
128 |
+
self.min = 1e10
|
129 |
+
self.reset()
|
130 |
+
|
131 |
+
def reset(self):
|
132 |
+
self.val = 0
|
133 |
+
self.avg = 0
|
134 |
+
self.sum = 0
|
135 |
+
self.count = 0
|
136 |
+
self.max = -1e10
|
137 |
+
self.min = 1e10
|
138 |
+
|
139 |
+
def update(self, val, n=1):
|
140 |
+
self.max = max(val, self.max)
|
141 |
+
self.min = min(val, self.min)
|
142 |
+
self.val = val
|
143 |
+
self.sum += val * n
|
144 |
+
self.count += n
|
145 |
+
self.avg = self.sum / self.count
|
146 |
+
|
147 |
+
|
148 |
+
def dissect_by_lengths(np_array, lengths, dim=0, assert_equal=True):
|
149 |
+
"""Dissect an array (N, D) into a list a sub-array,
|
150 |
+
np_array.shape[0] == sum(lengths), Output is a list of nd arrays, singlton dimention is kept"""
|
151 |
+
if assert_equal:
|
152 |
+
assert len(np_array) == sum(lengths)
|
153 |
+
length_indices = [0, ]
|
154 |
+
for i in range(len(lengths)):
|
155 |
+
length_indices.append(length_indices[i] + lengths[i])
|
156 |
+
if dim == 0:
|
157 |
+
array_list = [np_array[length_indices[i]:length_indices[i+1]] for i in range(len(lengths))]
|
158 |
+
elif dim == 1:
|
159 |
+
array_list = [np_array[:, length_indices[i]:length_indices[i + 1]] for i in range(len(lengths))]
|
160 |
+
elif dim == 2:
|
161 |
+
array_list = [np_array[:, :, length_indices[i]:length_indices[i + 1]] for i in range(len(lengths))]
|
162 |
+
else:
|
163 |
+
raise NotImplementedError
|
164 |
+
return array_list
|
165 |
+
|
166 |
+
|
167 |
+
def get_ratio_from_counter(counter_obj, threshold=200):
|
168 |
+
keys = counter_obj.keys()
|
169 |
+
values = counter_obj.values()
|
170 |
+
filtered_values = [counter_obj[k] for k in keys if k > threshold]
|
171 |
+
return float(sum(filtered_values)) / sum(values)
|
172 |
+
|
173 |
+
|
174 |
+
def get_counter_dist(counter_object, sort_type="none"):
|
175 |
+
_sum = sum(counter_object.values())
|
176 |
+
dist = {k: float(f"{100 * v / _sum:.2f}") for k, v in counter_object.items()}
|
177 |
+
if sort_type == "value":
|
178 |
+
dist = OrderedDict(sorted(dist.items(), reverse=True))
|
179 |
+
return dist
|
180 |
+
|
181 |
+
|
182 |
+
def get_show_name(vid_name):
|
183 |
+
"""
|
184 |
+
get tvshow name from vid_name
|
185 |
+
:param vid_name: video clip name
|
186 |
+
:return: tvshow name
|
187 |
+
"""
|
188 |
+
show_list = ["friends", "met", "castle", "house", "grey"]
|
189 |
+
vid_name_prefix = vid_name.split("_")[0]
|
190 |
+
show_name = vid_name_prefix if vid_name_prefix in show_list else "bbt"
|
191 |
+
return show_name
|
192 |
+
|
193 |
+
|
194 |
+
def get_abspaths_by_ext(dir_path, ext=(".jpg",)):
|
195 |
+
"""Get absolute paths to files in dir_path with extensions specified by ext.
|
196 |
+
Note this function does work recursively.
|
197 |
+
"""
|
198 |
+
if isinstance(ext, list):
|
199 |
+
ext = tuple(ext)
|
200 |
+
if isinstance(ext, str):
|
201 |
+
ext = tuple([ext, ])
|
202 |
+
filepaths = [os.path.join(root, name)
|
203 |
+
for root, dirs, files in os.walk(dir_path)
|
204 |
+
for name in files
|
205 |
+
if name.endswith(tuple(ext))]
|
206 |
+
return filepaths
|
207 |
+
|
208 |
+
|
209 |
+
def get_basename_no_ext(path):
|
210 |
+
""" '/data/movienet/240p_keyframe_feats/tt7672188.npz' --> 'tt7672188' """
|
211 |
+
return os.path.splitext(os.path.split(path)[1])[0]
|
212 |
+
|
213 |
+
|
214 |
+
def dict_to_markdown(d, max_str_len=120):
|
215 |
+
# convert list into its str representation
|
216 |
+
d = {k: v.__repr__() if isinstance(v, list) else v for k, v in d.items()}
|
217 |
+
# truncate string that is longer than max_str_len
|
218 |
+
if max_str_len is not None:
|
219 |
+
d = {k: v[-max_str_len:] if isinstance(v, str) else v for k, v in d.items()}
|
220 |
+
return pd.DataFrame(d, index=[0]).transpose().to_markdown()
|
221 |
+
|