Spaces:
Runtime error
wjf5203 committed
Commit 2aac0e2
1 Parent(s): acefb81
Add application file
This view is limited to 50 files because it contains too many changes. See raw diff.
- .DS_Store +0 -0
- GLEE/.DS_Store +0 -0
- GLEE/clip_vit_base_patch32/config.json +157 -0
- GLEE/clip_vit_base_patch32/merges.txt +0 -0
- GLEE/clip_vit_base_patch32/preprocessor_config.json +19 -0
- GLEE/clip_vit_base_patch32/pytorch_model.bin +3 -0
- GLEE/clip_vit_base_patch32/special_tokens_map.json +1 -0
- GLEE/clip_vit_base_patch32/tokenizer.json +0 -0
- GLEE/clip_vit_base_patch32/tokenizer_config.json +1 -0
- GLEE/clip_vit_base_patch32/vocab.json +0 -0
- GLEE/configs/R50.yaml +71 -0
- GLEE/configs/SwinL.yaml +79 -0
- GLEE/glee/.DS_Store +0 -0
- GLEE/glee/__init__.py +12 -0
- GLEE/glee/backbone/__init__.py +7 -0
- GLEE/glee/backbone/backbone.py +51 -0
- GLEE/glee/backbone/build.py +11 -0
- GLEE/glee/backbone/davit.py +623 -0
- GLEE/glee/backbone/eva01.py +676 -0
- GLEE/glee/backbone/eva02-dino.py +598 -0
- GLEE/glee/backbone/eva02.py +647 -0
- GLEE/glee/backbone/eva_01_utils.py +222 -0
- GLEE/glee/backbone/eva_02_utils.py +356 -0
- GLEE/glee/backbone/internimage.py +737 -0
- GLEE/glee/backbone/registry.py +14 -0
- GLEE/glee/backbone/resnet.py +731 -0
- GLEE/glee/backbone/swin.py +783 -0
- GLEE/glee/backbone/vit.py +472 -0
- GLEE/glee/backbone/vit_utils.py +222 -0
- GLEE/glee/config.py +387 -0
- GLEE/glee/config_deeplab.py +28 -0
- GLEE/glee/models/.DS_Store +0 -0
- GLEE/glee/models/glee_model.py +296 -0
- GLEE/glee/models/pixel_decoder/__init__.py +1 -0
- GLEE/glee/models/pixel_decoder/__pycache__/__init__.cpython-38.pyc +0 -0
- GLEE/glee/models/pixel_decoder/__pycache__/__init__.cpython-39.pyc +0 -0
- GLEE/glee/models/pixel_decoder/__pycache__/early_fusion.cpython-38.pyc +0 -0
- GLEE/glee/models/pixel_decoder/__pycache__/early_fusion.cpython-39.pyc +0 -0
- GLEE/glee/models/pixel_decoder/__pycache__/maskdino_encoder.cpython-38.pyc +0 -0
- GLEE/glee/models/pixel_decoder/__pycache__/maskdino_encoder.cpython-39.pyc +0 -0
- GLEE/glee/models/pixel_decoder/__pycache__/position_encoding.cpython-38.pyc +0 -0
- GLEE/glee/models/pixel_decoder/__pycache__/position_encoding.cpython-39.pyc +0 -0
- GLEE/glee/models/pixel_decoder/early_fusion.py +230 -0
- GLEE/glee/models/pixel_decoder/maskdino_encoder.py +463 -0
- GLEE/glee/models/pixel_decoder/ops/functions/__init__.py +13 -0
- GLEE/glee/models/pixel_decoder/ops/functions/__pycache__/__init__.cpython-38.pyc +0 -0
- GLEE/glee/models/pixel_decoder/ops/functions/__pycache__/__init__.cpython-39.pyc +0 -0
- GLEE/glee/models/pixel_decoder/ops/functions/__pycache__/ms_deform_attn_func.cpython-38.pyc +0 -0
- GLEE/glee/models/pixel_decoder/ops/functions/__pycache__/ms_deform_attn_func.cpython-39.pyc +0 -0
- GLEE/glee/models/pixel_decoder/ops/functions/ms_deform_attn_func.py +72 -0
.DS_Store
ADDED
Binary file (6.15 kB). View file
GLEE/.DS_Store
ADDED
Binary file (6.15 kB). View file
GLEE/clip_vit_base_patch32/config.json
ADDED
@@ -0,0 +1,157 @@
{
  "_name_or_path": "openai/clip-vit-base-patch32",
  "architectures": [
    "CLIPModel"
  ],
  "initializer_factor": 1.0,
  "logit_scale_init_value": 2.6592,
  "model_type": "clip",
  "projection_dim": 512,
  "text_config": {
    "_name_or_path": "",
    "add_cross_attention": false,
    "architectures": null,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "bos_token_id": 0,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "dropout": 0.0,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 2,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "quick_gelu",
    "hidden_size": 512,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 2048,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-05,
    "length_penalty": 1.0,
    "max_length": 20,
    "max_position_embeddings": 77,
    "min_length": 0,
    "model_type": "clip_text_model",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 8,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_hidden_layers": 12,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": 1,
    "prefix": null,
    "projection_dim": 512,
    "problem_type": null,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "transformers_version": "4.16.0.dev0",
    "use_bfloat16": false,
    "vocab_size": 49408
  },
  "text_config_dict": null,
  "transformers_version": null,
  "vision_config": {
    "_name_or_path": "",
    "add_cross_attention": false,
    "architectures": null,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "dropout": 0.0,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "quick_gelu",
    "hidden_size": 768,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "image_size": 224,
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-05,
    "length_penalty": 1.0,
    "max_length": 20,
    "min_length": 0,
    "model_type": "clip_vision_model",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 12,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_hidden_layers": 12,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
    "patch_size": 32,
    "prefix": null,
    "projection_dim": 512,
    "problem_type": null,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "transformers_version": "4.16.0.dev0",
    "use_bfloat16": false
  },
  "vision_config_dict": null
}
GLEE/clip_vit_base_patch32/merges.txt
ADDED
The diff for this file is too large to render.
See raw diff
GLEE/clip_vit_base_patch32/preprocessor_config.json
ADDED
@@ -0,0 +1,19 @@
{
  "crop_size": 224,
  "do_center_crop": true,
  "do_normalize": true,
  "do_resize": true,
  "feature_extractor_type": "CLIPFeatureExtractor",
  "image_mean": [
    0.48145466,
    0.4578275,
    0.40821073
  ],
  "image_std": [
    0.26862954,
    0.26130258,
    0.27577711
  ],
  "resample": 3,
  "size": 224
}
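For reference, the preprocessing this file describes can be approximated with plain torchvision transforms; resample: 3 is PIL's bicubic filter and the mean/std values are the standard CLIP statistics. A minimal sketch (the image path is only an example, not a file in this commit):

from PIL import Image
from torchvision import transforms

# Mirrors CLIPFeatureExtractor with the settings above: bicubic shortest-edge
# resize to 224, center crop to 224, then per-channel normalization.
clip_preprocess = transforms.Compose([
    transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                         std=[0.26862954, 0.26130258, 0.27577711]),
])
pixel_values = clip_preprocess(Image.open("example.jpg").convert("RGB")).unsqueeze(0)  # (1, 3, 224, 224)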
GLEE/clip_vit_base_patch32/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a63082132ba4f97a80bea76823f544493bffa8082296d62d71581a4feff1576f
size 605247071
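This is a Git LFS pointer, not the weights themselves; the ~605 MB binary is stored as an LFS object identified by the SHA-256 above. After cloning with LFS enabled, the download can be sanity-checked against that digest and size, e.g. with a small script like this sketch:

import hashlib
import os

def sha256_of(path, chunk_size=1 << 20):
    """Stream the file so the 605 MB checkpoint never sits fully in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()

path = "GLEE/clip_vit_base_patch32/pytorch_model.bin"
assert os.path.getsize(path) == 605247071
assert sha256_of(path) == "a63082132ba4f97a80bea76823f544493bffa8082296d62d71581a4feff1576f"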
GLEE/clip_vit_base_patch32/special_tokens_map.json
ADDED
@@ -0,0 +1 @@
{"bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": "<|endoftext|>"}
GLEE/clip_vit_base_patch32/tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
GLEE/clip_vit_base_patch32/tokenizer_config.json
ADDED
@@ -0,0 +1 @@
{"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "pad_token": "<|endoftext|>", "add_prefix_space": false, "errors": "replace", "do_lower_case": true, "name_or_path": "./clip_ViT_B_32/"}
GLEE/clip_vit_base_patch32/vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
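Taken together, the files above are a self-contained local snapshot of openai/clip-vit-base-patch32, so the Space does not have to fetch it at runtime. A minimal sketch of loading the snapshot with Hugging Face transformers (path relative to the repository root; a transformers version compatible with the bundled 4.16-era config is assumed):

import torch
from transformers import CLIPModel, CLIPProcessor

local_dir = "GLEE/clip_vit_base_patch32"
model = CLIPModel.from_pretrained(local_dir)          # config.json + pytorch_model.bin
processor = CLIPProcessor.from_pretrained(local_dir)  # tokenizer + feature-extractor files

# Text features land in the 512-d projection space declared by "projection_dim".
inputs = processor(text=["a photo of a cat"], return_tensors="pt", padding=True)
with torch.no_grad():
    text_features = model.get_text_features(**inputs)
print(text_features.shape)  # torch.Size([1, 512])

GLEE points its language backbone at this checkpoint (ARCH: clip_teacher, LANG_DIM: 512 in the configs that follow); the snippet only shows that the bundled snapshot is a standard CLIP checkpoint.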
GLEE/configs/R50.yaml
ADDED
@@ -0,0 +1,71 @@
MODEL:
  META_ARCHITECTURE: "GLEE"
  MASK_ON: True
  BACKBONE:
    FREEZE_AT: 0
    NAME: "build_resnet_backbone"
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.120, 57.375]
  RESNETS:
    DEPTH: 50
    STEM_TYPE: "basic"  # not used
    STEM_OUT_CHANNELS: 64
    STRIDE_IN_1X1: False
    OUT_FEATURES: ["res2", "res3", "res4", "res5"]
    # NORM: "SyncBN"
    RES5_MULTI_GRID: [1, 1, 1]  # not used
  SEM_SEG_HEAD:
    NAME: "MaskDINOHead"
    IGNORE_VALUE: 255
    NUM_CLASSES: 80
    LOSS_WEIGHT: 1.0
    CONVS_DIM: 256
    MASK_DIM: 256
    NORM: "GN"
    # pixel decoder
    PIXEL_DECODER_NAME: "MaskDINOEncoder"
    DIM_FEEDFORWARD: 2048
    NUM_FEATURE_LEVELS: 3
    TOTAL_NUM_FEATURE_LEVELS: 4
    IN_FEATURES: ["res2", "res3", "res4", "res5"]
    DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
    COMMON_STRIDE: 4
    TRANSFORMER_ENC_LAYERS: 6
    FEATURE_ORDER: "low2high"
  MaskDINO:
    TRANSFORMER_DECODER_NAME: "MaskDINODecoder"
    DEEP_SUPERVISION: True
    NO_OBJECT_WEIGHT: 0.1
    CLASS_WEIGHT: 4.0
    MASK_WEIGHT: 5.0
    DICE_WEIGHT: 5.0
    BOX_WEIGHT: 5.0
    GIOU_WEIGHT: 2.0
    HIDDEN_DIM: 256
    NUM_OBJECT_QUERIES: 300
    NHEADS: 8
    DROPOUT: 0.0
    DIM_FEEDFORWARD: 2048
    ENC_LAYERS: 0
    PRE_NORM: False
    ENFORCE_INPUT_PROJ: False
    SIZE_DIVISIBILITY: 32
    DEC_LAYERS: 9  # 9+1, 9 decoder layers, add one for the loss on learnable query
    TRAIN_NUM_POINTS: 12544
    OVERSAMPLE_RATIO: 3.0
    IMPORTANCE_SAMPLE_RATIO: 0.75
    INITIAL_PRED: True
    TWO_STAGE: True
    DN: "standard"
    DN_NUM: 100
    INITIALIZE_BOX_TYPE: "no"
    TEST:
      SEMANTIC_ON: False
      INSTANCE_ON: True
      PANOPTIC_ON: False
      OVERLAP_THRESHOLD: 0.8
      OBJECT_MASK_THRESHOLD: 0.25
  TEXT:
    ARCH: clip_teacher
  LANGUAGE_BACKBONE:
    LANG_DIM: 512
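A sketch of how a detectron2-style YAML like this is typically consumed, using the add_glee_config / add_deeplab_config helpers exported by GLEE/glee/__init__.py further down in this commit. The exact call site inside the application is not shown in this view, so the import path and the order of the add_*_config calls are assumptions:

from detectron2.config import get_cfg

# add_deeplab_config / add_glee_config register the GLEE-specific keys
# (MaskDINO, TEXT, LANGUAGE_BACKBONE, ...) before the YAML is merged.
from glee import add_deeplab_config, add_glee_config  # import path assumed

cfg = get_cfg()
add_deeplab_config(cfg)
add_glee_config(cfg)
cfg.merge_from_file("GLEE/configs/R50.yaml")
cfg.MODEL.WEIGHTS = "path/to/GLEE_R50_checkpoint.pth"  # hypothetical checkpoint path
cfg.freeze()

print(cfg.MODEL.MaskDINO.NUM_OBJECT_QUERIES)  # 300
print(cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES)     # 80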
GLEE/configs/SwinL.yaml
ADDED
@@ -0,0 +1,79 @@
MODEL:
  META_ARCHITECTURE: "GLEE"
  MASK_ON: True
  BACKBONE:
    NAME: "D2SwinTransformer"
  SWIN:
    EMBED_DIM: 192
    DEPTHS: [2, 2, 18, 2]
    NUM_HEADS: [6, 12, 24, 48]
    WINDOW_SIZE: 12
    APE: False
    DROP_PATH_RATE: 0.3
    PATCH_NORM: True
    PRETRAIN_IMG_SIZE: 384
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.120, 57.375]
  RESNETS:
    DEPTH: 50
    STEM_TYPE: "basic"  # not used
    STEM_OUT_CHANNELS: 64
    STRIDE_IN_1X1: False
    OUT_FEATURES: ["res2", "res3", "res4", "res5"]
    # NORM: "SyncBN"
    RES5_MULTI_GRID: [1, 1, 1]  # not used
  SEM_SEG_HEAD:
    NAME: "MaskDINOHead"
    IGNORE_VALUE: 255
    NUM_CLASSES: 80
    LOSS_WEIGHT: 1.0
    CONVS_DIM: 256
    MASK_DIM: 256
    NORM: "GN"
    # pixel decoder
    PIXEL_DECODER_NAME: "MaskDINOEncoder"
    DIM_FEEDFORWARD: 2048
    NUM_FEATURE_LEVELS: 3
    TOTAL_NUM_FEATURE_LEVELS: 4
    IN_FEATURES: ["res2", "res3", "res4", "res5"]
    DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
    COMMON_STRIDE: 4
    TRANSFORMER_ENC_LAYERS: 6
    FEATURE_ORDER: "low2high"
  MaskDINO:
    TRANSFORMER_DECODER_NAME: "MaskDINODecoder"
    DEEP_SUPERVISION: True
    NO_OBJECT_WEIGHT: 0.1
    CLASS_WEIGHT: 4.0
    MASK_WEIGHT: 5.0
    DICE_WEIGHT: 5.0
    BOX_WEIGHT: 5.0
    GIOU_WEIGHT: 2.0
    HIDDEN_DIM: 256
    NUM_OBJECT_QUERIES: 300
    NHEADS: 8
    DROPOUT: 0.0
    DIM_FEEDFORWARD: 2048
    ENC_LAYERS: 0
    PRE_NORM: False
    ENFORCE_INPUT_PROJ: False
    SIZE_DIVISIBILITY: 32
    DEC_LAYERS: 9  # 9+1, 9 decoder layers, add one for the loss on learnable query
    TRAIN_NUM_POINTS: 12544
    OVERSAMPLE_RATIO: 3.0
    IMPORTANCE_SAMPLE_RATIO: 0.75
    INITIAL_PRED: True
    TWO_STAGE: True
    DN: "standard"
    DN_NUM: 100
    INITIALIZE_BOX_TYPE: "no"
    TEST:
      SEMANTIC_ON: False
      INSTANCE_ON: True
      PANOPTIC_ON: False
      OVERLAP_THRESHOLD: 0.8
      OBJECT_MASK_THRESHOLD: 0.25
  TEXT:
    ARCH: clip_teacher
  LANGUAGE_BACKBONE:
    LANG_DIM: 512
GLEE/glee/.DS_Store
ADDED
Binary file (6.15 kB). View file
GLEE/glee/__init__.py
ADDED
@@ -0,0 +1,12 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from .config import add_glee_config
from .config_deeplab import add_deeplab_config
# from .GLEE import GLEE
# from .data import build_detection_train_loader, build_detection_test_loader
from .backbone.swin import D2SwinTransformer
from .backbone.eva02 import D2_EVA02
GLEE/glee/backbone/__init__.py
ADDED
@@ -0,0 +1,7 @@
from .build import build_backbone

from .resnet import *
from .swin import *
# from .focal import *
# from .focal_dw import *
from .backbone import *
GLEE/glee/backbone/backbone.py
ADDED
@@ -0,0 +1,51 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import torch.nn as nn

from detectron2.modeling import ShapeSpec

__all__ = ["Backbone"]


class Backbone(nn.Module):
    """
    Abstract base class for network backbones.
    """

    def __init__(self):
        """
        The `__init__` method of any subclass can specify its own set of arguments.
        """
        super().__init__()

    def forward(self):
        """
        Subclasses must override this method, but adhere to the same return type.

        Returns:
            dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor
        """
        pass

    @property
    def size_divisibility(self) -> int:
        """
        Some backbones require the input height and width to be divisible by a
        specific integer. This is typically true for encoder / decoder type networks
        with lateral connection (e.g., FPN) for which feature maps need to match
        dimension in the "bottom up" and "top down" paths. Set to 0 if no specific
        input size divisibility is required.
        """
        return 0

    def output_shape(self):
        """
        Returns:
            dict[str->ShapeSpec]
        """
        # this is a backward-compatible default
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
            )
            for name in self._out_features
        }
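The class above only fixes the contract: forward() returns a dict of named feature maps, and output_shape() describes them through the _out_feature_* attributes a subclass is expected to set. A toy subclass, just to make that contract concrete (not part of the repository; it assumes the Backbone class above is importable in scope):

import torch
import torch.nn as nn

class ToyBackbone(Backbone):
    """Single-stage example: one stride-4 stem producing a 64-channel "res2" map."""

    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 64, kernel_size=7, stride=4, padding=3)
        self._out_features = ["res2"]
        self._out_feature_channels = {"res2": 64}
        self._out_feature_strides = {"res2": 4}

    def forward(self, x):
        return {"res2": self.stem(x)}

feats = ToyBackbone()(torch.randn(2, 3, 64, 64))
assert feats["res2"].shape == (2, 64, 16, 16)          # stride 4, 64 channels
assert ToyBackbone().output_shape()["res2"].channels == 64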
GLEE/glee/backbone/build.py
ADDED
@@ -0,0 +1,11 @@
from .registry import model_entrypoints
from .registry import is_model

from .backbone import *

def build_backbone(config, **kwargs):
    model_name = config['MODEL']['BACKBONE']['NAME']
    if not is_model(model_name):
        raise ValueError(f'Unknown model: {model_name}')
    model = model_entrypoints(model_name)(config, **kwargs)
    return model
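build_backbone resolves the name in MODEL.BACKBONE.NAME through .registry, whose 14-line implementation is not rendered in this view. A plausible minimal sketch of such a registry, consistent with how build.py calls it and with the @register_backbone decorators used in davit.py (an assumption, not the actual file contents):

# registry.py sketch: a module-level dict keyed by entrypoint function name.
_entrypoints = {}

def register_backbone(fn):
    """Decorator used by backbone modules, e.g. @register_backbone on get_davit_backbone."""
    _entrypoints[fn.__name__] = fn
    return fn

def model_entrypoints(model_name):
    return _entrypoints[model_name]

def is_model(model_name):
    return model_name in _entrypoints

With a registry like this, setting MODEL.BACKBONE.NAME to a registered entrypoint name is enough for build_backbone(config) to construct the matching backbone.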
GLEE/glee/backbone/davit.py
ADDED
@@ -0,0 +1,623 @@
1 |
+
import os
|
2 |
+
import itertools
|
3 |
+
import logging
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import torch.utils.checkpoint as checkpoint
|
9 |
+
from collections import OrderedDict
|
10 |
+
|
11 |
+
from einops import rearrange
|
12 |
+
from timm.models.layers import DropPath, trunc_normal_
|
13 |
+
|
14 |
+
from detectron2.utils.file_io import PathManager
|
15 |
+
from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
|
16 |
+
|
17 |
+
from .registry import register_backbone
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
class MySequential(nn.Sequential):
|
23 |
+
def forward(self, *inputs):
|
24 |
+
for module in self._modules.values():
|
25 |
+
if type(inputs) == tuple:
|
26 |
+
inputs = module(*inputs)
|
27 |
+
else:
|
28 |
+
inputs = module(inputs)
|
29 |
+
return inputs
|
30 |
+
|
31 |
+
|
32 |
+
class PreNorm(nn.Module):
|
33 |
+
def __init__(self, norm, fn, drop_path=None):
|
34 |
+
super().__init__()
|
35 |
+
self.norm = norm
|
36 |
+
self.fn = fn
|
37 |
+
self.drop_path = drop_path
|
38 |
+
|
39 |
+
def forward(self, x, *args, **kwargs):
|
40 |
+
shortcut = x
|
41 |
+
if self.norm != None:
|
42 |
+
x, size = self.fn(self.norm(x), *args, **kwargs)
|
43 |
+
else:
|
44 |
+
x, size = self.fn(x, *args, **kwargs)
|
45 |
+
|
46 |
+
if self.drop_path:
|
47 |
+
x = self.drop_path(x)
|
48 |
+
|
49 |
+
x = shortcut + x
|
50 |
+
|
51 |
+
return x, size
|
52 |
+
|
53 |
+
|
54 |
+
class Mlp(nn.Module):
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
in_features,
|
58 |
+
hidden_features=None,
|
59 |
+
out_features=None,
|
60 |
+
act_layer=nn.GELU,
|
61 |
+
):
|
62 |
+
super().__init__()
|
63 |
+
out_features = out_features or in_features
|
64 |
+
hidden_features = hidden_features or in_features
|
65 |
+
self.net = nn.Sequential(OrderedDict([
|
66 |
+
("fc1", nn.Linear(in_features, hidden_features)),
|
67 |
+
("act", act_layer()),
|
68 |
+
("fc2", nn.Linear(hidden_features, out_features))
|
69 |
+
]))
|
70 |
+
|
71 |
+
def forward(self, x, size):
|
72 |
+
return self.net(x), size
|
73 |
+
|
74 |
+
|
75 |
+
class DepthWiseConv2d(nn.Module):
|
76 |
+
def __init__(
|
77 |
+
self,
|
78 |
+
dim_in,
|
79 |
+
kernel_size,
|
80 |
+
padding,
|
81 |
+
stride,
|
82 |
+
bias=True,
|
83 |
+
):
|
84 |
+
super().__init__()
|
85 |
+
self.dw = nn.Conv2d(
|
86 |
+
dim_in, dim_in,
|
87 |
+
kernel_size=kernel_size,
|
88 |
+
padding=padding,
|
89 |
+
groups=dim_in,
|
90 |
+
stride=stride,
|
91 |
+
bias=bias
|
92 |
+
)
|
93 |
+
|
94 |
+
def forward(self, x, size):
|
95 |
+
B, N, C = x.shape
|
96 |
+
H, W = size
|
97 |
+
assert N == H * W
|
98 |
+
|
99 |
+
x = self.dw(x.transpose(1, 2).view(B, C, H, W))
|
100 |
+
size = (x.size(-2), x.size(-1))
|
101 |
+
x = x.flatten(2).transpose(1, 2)
|
102 |
+
return x, size
|
103 |
+
|
104 |
+
|
105 |
+
class ConvEmbed(nn.Module):
|
106 |
+
""" Image to Patch Embedding
|
107 |
+
"""
|
108 |
+
|
109 |
+
def __init__(
|
110 |
+
self,
|
111 |
+
patch_size=7,
|
112 |
+
in_chans=3,
|
113 |
+
embed_dim=64,
|
114 |
+
stride=4,
|
115 |
+
padding=2,
|
116 |
+
norm_layer=None,
|
117 |
+
pre_norm=True
|
118 |
+
):
|
119 |
+
super().__init__()
|
120 |
+
self.patch_size = patch_size
|
121 |
+
|
122 |
+
self.proj = nn.Conv2d(
|
123 |
+
in_chans, embed_dim,
|
124 |
+
kernel_size=patch_size,
|
125 |
+
stride=stride,
|
126 |
+
padding=padding
|
127 |
+
)
|
128 |
+
|
129 |
+
dim_norm = in_chans if pre_norm else embed_dim
|
130 |
+
self.norm = norm_layer(dim_norm) if norm_layer else None
|
131 |
+
|
132 |
+
self.pre_norm = pre_norm
|
133 |
+
|
134 |
+
def forward(self, x, size):
|
135 |
+
H, W = size
|
136 |
+
if len(x.size()) == 3:
|
137 |
+
if self.norm and self.pre_norm:
|
138 |
+
x = self.norm(x)
|
139 |
+
x = rearrange(
|
140 |
+
x, 'b (h w) c -> b c h w',
|
141 |
+
h=H, w=W
|
142 |
+
)
|
143 |
+
|
144 |
+
x = self.proj(x)
|
145 |
+
|
146 |
+
_, _, H, W = x.shape
|
147 |
+
x = rearrange(x, 'b c h w -> b (h w) c')
|
148 |
+
if self.norm and not self.pre_norm:
|
149 |
+
x = self.norm(x)
|
150 |
+
|
151 |
+
return x, (H, W)
|
152 |
+
|
153 |
+
|
154 |
+
class ChannelAttention(nn.Module):
|
155 |
+
|
156 |
+
def __init__(self, dim, groups=8, qkv_bias=True):
|
157 |
+
super().__init__()
|
158 |
+
|
159 |
+
self.groups = groups
|
160 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
161 |
+
self.proj = nn.Linear(dim, dim)
|
162 |
+
|
163 |
+
def forward(self, x, size):
|
164 |
+
B, N, C = x.shape
|
165 |
+
|
166 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4)
|
167 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
168 |
+
|
169 |
+
q = q * (N ** -0.5)
|
170 |
+
attention = q.transpose(-1, -2) @ k
|
171 |
+
attention = attention.softmax(dim=-1)
|
172 |
+
x = (attention @ v.transpose(-1, -2)).transpose(-1, -2)
|
173 |
+
x = x.transpose(1, 2).reshape(B, N, C)
|
174 |
+
x = self.proj(x)
|
175 |
+
return x, size
|
176 |
+
|
177 |
+
|
178 |
+
class ChannelBlock(nn.Module):
|
179 |
+
|
180 |
+
def __init__(self, dim, groups, mlp_ratio=4., qkv_bias=True,
|
181 |
+
drop_path_rate=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
|
182 |
+
conv_at_attn=True, conv_at_ffn=True):
|
183 |
+
super().__init__()
|
184 |
+
|
185 |
+
drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
186 |
+
|
187 |
+
self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
|
188 |
+
self.channel_attn = PreNorm(
|
189 |
+
norm_layer(dim),
|
190 |
+
ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias),
|
191 |
+
drop_path
|
192 |
+
)
|
193 |
+
self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
|
194 |
+
self.ffn = PreNorm(
|
195 |
+
norm_layer(dim),
|
196 |
+
Mlp(in_features=dim, hidden_features=int(dim*mlp_ratio), act_layer=act_layer),
|
197 |
+
drop_path
|
198 |
+
)
|
199 |
+
|
200 |
+
def forward(self, x, size):
|
201 |
+
if self.conv1:
|
202 |
+
x, size = self.conv1(x, size)
|
203 |
+
x, size = self.channel_attn(x, size)
|
204 |
+
|
205 |
+
if self.conv2:
|
206 |
+
x, size = self.conv2(x, size)
|
207 |
+
x, size = self.ffn(x, size)
|
208 |
+
|
209 |
+
return x, size
|
210 |
+
|
211 |
+
|
212 |
+
def window_partition(x, window_size: int):
|
213 |
+
B, H, W, C = x.shape
|
214 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
215 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
216 |
+
return windows
|
217 |
+
|
218 |
+
|
219 |
+
def window_reverse(windows, window_size: int, H: int, W: int):
|
220 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
221 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
222 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
223 |
+
return x
|
224 |
+
|
225 |
+
|
226 |
+
class WindowAttention(nn.Module):
|
227 |
+
def __init__(self, dim, num_heads, window_size, qkv_bias=True):
|
228 |
+
|
229 |
+
super().__init__()
|
230 |
+
self.dim = dim
|
231 |
+
self.window_size = window_size
|
232 |
+
self.num_heads = num_heads
|
233 |
+
head_dim = dim // num_heads
|
234 |
+
self.scale = head_dim ** -0.5
|
235 |
+
|
236 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
237 |
+
self.proj = nn.Linear(dim, dim)
|
238 |
+
|
239 |
+
self.softmax = nn.Softmax(dim=-1)
|
240 |
+
|
241 |
+
def forward(self, x, size):
|
242 |
+
|
243 |
+
H, W = size
|
244 |
+
B, L, C = x.shape
|
245 |
+
assert L == H * W, "input feature has wrong size"
|
246 |
+
|
247 |
+
x = x.view(B, H, W, C)
|
248 |
+
|
249 |
+
pad_l = pad_t = 0
|
250 |
+
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
251 |
+
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
252 |
+
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
253 |
+
_, Hp, Wp, _ = x.shape
|
254 |
+
|
255 |
+
x = window_partition(x, self.window_size)
|
256 |
+
x = x.view(-1, self.window_size * self.window_size, C)
|
257 |
+
|
258 |
+
# W-MSA/SW-MSA
|
259 |
+
# attn_windows = self.attn(x_windows)
|
260 |
+
|
261 |
+
B_, N, C = x.shape
|
262 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
263 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
264 |
+
|
265 |
+
q = q * self.scale
|
266 |
+
attn = (q @ k.transpose(-2, -1))
|
267 |
+
attn = self.softmax(attn)
|
268 |
+
|
269 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
270 |
+
x = self.proj(x)
|
271 |
+
|
272 |
+
# merge windows
|
273 |
+
x = x.view(
|
274 |
+
-1, self.window_size, self.window_size, C
|
275 |
+
)
|
276 |
+
x = window_reverse(x, self.window_size, Hp, Wp)
|
277 |
+
|
278 |
+
if pad_r > 0 or pad_b > 0:
|
279 |
+
x = x[:, :H, :W, :].contiguous()
|
280 |
+
|
281 |
+
x = x.view(B, H * W, C)
|
282 |
+
|
283 |
+
return x, size
|
284 |
+
|
285 |
+
|
286 |
+
class SpatialBlock(nn.Module):
|
287 |
+
|
288 |
+
def __init__(self, dim, num_heads, window_size,
|
289 |
+
mlp_ratio=4., qkv_bias=True, drop_path_rate=0., act_layer=nn.GELU,
|
290 |
+
norm_layer=nn.LayerNorm, conv_at_attn=True, conv_at_ffn=True):
|
291 |
+
super().__init__()
|
292 |
+
|
293 |
+
drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
294 |
+
|
295 |
+
self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
|
296 |
+
self.window_attn = PreNorm(
|
297 |
+
norm_layer(dim),
|
298 |
+
WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias),
|
299 |
+
drop_path
|
300 |
+
)
|
301 |
+
self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
|
302 |
+
self.ffn = PreNorm(
|
303 |
+
norm_layer(dim),
|
304 |
+
Mlp(in_features=dim, hidden_features=int(dim*mlp_ratio), act_layer=act_layer),
|
305 |
+
drop_path
|
306 |
+
)
|
307 |
+
|
308 |
+
def forward(self, x, size):
|
309 |
+
if self.conv1:
|
310 |
+
x, size = self.conv1(x, size)
|
311 |
+
x, size = self.window_attn(x, size)
|
312 |
+
|
313 |
+
if self.conv2:
|
314 |
+
x, size = self.conv2(x, size)
|
315 |
+
x, size = self.ffn(x, size)
|
316 |
+
return x, size
|
317 |
+
|
318 |
+
|
319 |
+
class DaViT(nn.Module):
|
320 |
+
""" DaViT: Dual-Attention Transformer
|
321 |
+
|
322 |
+
Args:
|
323 |
+
img_size (int): Image size, Default: 224.
|
324 |
+
in_chans (int): Number of input image channels. Default: 3.
|
325 |
+
num_classes (int): Number of classes for classification head. Default: 1000.
|
326 |
+
patch_size (tuple(int)): Patch size of convolution in different stages. Default: (7, 2, 2, 2).
|
327 |
+
patch_stride (tuple(int)): Patch stride of convolution in different stages. Default: (4, 2, 2, 2).
|
328 |
+
patch_padding (tuple(int)): Patch padding of convolution in different stages. Default: (3, 0, 0, 0).
|
329 |
+
patch_prenorm (tuple(bool)): If True, perform norm before convlution layer. Default: (True, False, False, False).
|
330 |
+
embed_dims (tuple(int)): Patch embedding dimension in different stages. Default: (64, 128, 192, 256).
|
331 |
+
num_heads (tuple(int)): Number of spatial attention heads in different stages. Default: (4, 8, 12, 16).
|
332 |
+
num_groups (tuple(int)): Number of channel groups in different stages. Default: (4, 8, 12, 16).
|
333 |
+
window_size (int): Window size. Default: 7.
|
334 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
335 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True.
|
336 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.1.
|
337 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
338 |
+
enable_checkpoint (bool): If True, enable checkpointing. Default: False.
|
339 |
+
conv_at_attn (bool): If True, performe depthwise convolution before attention layer. Default: True.
|
340 |
+
conv_at_ffn (bool): If True, performe depthwise convolution before ffn layer. Default: True.
|
341 |
+
"""
|
342 |
+
|
343 |
+
def __init__(
|
344 |
+
self,
|
345 |
+
img_size=224,
|
346 |
+
in_chans=3,
|
347 |
+
num_classes=1000,
|
348 |
+
depths=(1, 1, 3, 1),
|
349 |
+
patch_size=(7, 2, 2, 2),
|
350 |
+
patch_stride=(4, 2, 2, 2),
|
351 |
+
patch_padding=(3, 0, 0, 0),
|
352 |
+
patch_prenorm=(False, False, False, False),
|
353 |
+
embed_dims=(64, 128, 192, 256),
|
354 |
+
num_heads=(3, 6, 12, 24),
|
355 |
+
num_groups=(3, 6, 12, 24),
|
356 |
+
window_size=7,
|
357 |
+
mlp_ratio=4.,
|
358 |
+
qkv_bias=True,
|
359 |
+
drop_path_rate=0.1,
|
360 |
+
norm_layer=nn.LayerNorm,
|
361 |
+
enable_checkpoint=False,
|
362 |
+
conv_at_attn=True,
|
363 |
+
conv_at_ffn=True,
|
364 |
+
out_indices=[],
|
365 |
+
):
|
366 |
+
super().__init__()
|
367 |
+
|
368 |
+
self.num_classes = num_classes
|
369 |
+
self.embed_dims = embed_dims
|
370 |
+
self.num_heads = num_heads
|
371 |
+
self.num_groups = num_groups
|
372 |
+
self.num_stages = len(self.embed_dims)
|
373 |
+
self.enable_checkpoint = enable_checkpoint
|
374 |
+
assert self.num_stages == len(self.num_heads) == len(self.num_groups)
|
375 |
+
|
376 |
+
num_stages = len(embed_dims)
|
377 |
+
self.img_size = img_size
|
378 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)*2)]
|
379 |
+
|
380 |
+
|
381 |
+
depth_offset = 0
|
382 |
+
convs = []
|
383 |
+
blocks = []
|
384 |
+
for i in range(num_stages):
|
385 |
+
conv_embed = ConvEmbed(
|
386 |
+
patch_size=patch_size[i],
|
387 |
+
stride=patch_stride[i],
|
388 |
+
padding=patch_padding[i],
|
389 |
+
in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
|
390 |
+
embed_dim=self.embed_dims[i],
|
391 |
+
norm_layer=norm_layer,
|
392 |
+
pre_norm=patch_prenorm[i]
|
393 |
+
)
|
394 |
+
convs.append(conv_embed)
|
395 |
+
|
396 |
+
print(f'=> Depth offset in stage {i}: {depth_offset}')
|
397 |
+
block = MySequential(
|
398 |
+
*[
|
399 |
+
MySequential(OrderedDict([
|
400 |
+
(
|
401 |
+
'spatial_block', SpatialBlock(
|
402 |
+
embed_dims[i],
|
403 |
+
num_heads[i],
|
404 |
+
window_size,
|
405 |
+
drop_path_rate=dpr[depth_offset+j*2],
|
406 |
+
qkv_bias=qkv_bias,
|
407 |
+
mlp_ratio=mlp_ratio,
|
408 |
+
conv_at_attn=conv_at_attn,
|
409 |
+
conv_at_ffn=conv_at_ffn,
|
410 |
+
)
|
411 |
+
),
|
412 |
+
(
|
413 |
+
'channel_block', ChannelBlock(
|
414 |
+
embed_dims[i],
|
415 |
+
num_groups[i],
|
416 |
+
drop_path_rate=dpr[depth_offset+j*2+1],
|
417 |
+
qkv_bias=qkv_bias,
|
418 |
+
mlp_ratio=mlp_ratio,
|
419 |
+
conv_at_attn=conv_at_attn,
|
420 |
+
conv_at_ffn=conv_at_ffn,
|
421 |
+
)
|
422 |
+
)
|
423 |
+
])) for j in range(depths[i])
|
424 |
+
]
|
425 |
+
)
|
426 |
+
blocks.append(block)
|
427 |
+
depth_offset += depths[i]*2
|
428 |
+
|
429 |
+
self.convs = nn.ModuleList(convs)
|
430 |
+
self.blocks = nn.ModuleList(blocks)
|
431 |
+
|
432 |
+
self.out_indices = out_indices
|
433 |
+
# self.norms = norm_layer(self.embed_dims[-1])
|
434 |
+
# self.avgpool = nn.AdaptiveAvgPool1d(1)
|
435 |
+
# self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
|
436 |
+
self.apply(self._init_weights)
|
437 |
+
|
438 |
+
@property
|
439 |
+
def dim_out(self):
|
440 |
+
return self.embed_dims[-1]
|
441 |
+
|
442 |
+
def _init_weights(self, m):
|
443 |
+
if isinstance(m, nn.Linear):
|
444 |
+
trunc_normal_(m.weight, std=0.02)
|
445 |
+
if m.bias is not None:
|
446 |
+
nn.init.constant_(m.bias, 0)
|
447 |
+
elif isinstance(m, nn.Conv2d):
|
448 |
+
nn.init.normal_(m.weight, std=0.02)
|
449 |
+
for name, _ in m.named_parameters():
|
450 |
+
if name in ['bias']:
|
451 |
+
nn.init.constant_(m.bias, 0)
|
452 |
+
elif isinstance(m, nn.LayerNorm):
|
453 |
+
nn.init.constant_(m.weight, 1.0)
|
454 |
+
nn.init.constant_(m.bias, 0)
|
455 |
+
elif isinstance(m, nn.BatchNorm2d):
|
456 |
+
nn.init.constant_(m.weight, 1.0)
|
457 |
+
nn.init.constant_(m.bias, 0)
|
458 |
+
|
459 |
+
def _try_remap_keys(self, pretrained_dict):
|
460 |
+
remap_keys = {
|
461 |
+
"conv_embeds": "convs",
|
462 |
+
"main_blocks": "blocks",
|
463 |
+
"0.cpe.0.proj": "spatial_block.conv1.fn.dw",
|
464 |
+
"0.attn": "spatial_block.window_attn.fn",
|
465 |
+
"0.cpe.1.proj": "spatial_block.conv2.fn.dw",
|
466 |
+
"0.mlp": "spatial_block.ffn.fn.net",
|
467 |
+
"1.cpe.0.proj": "channel_block.conv1.fn.dw",
|
468 |
+
"1.attn": "channel_block.channel_attn.fn",
|
469 |
+
"1.cpe.1.proj": "channel_block.conv2.fn.dw",
|
470 |
+
"1.mlp": "channel_block.ffn.fn.net",
|
471 |
+
"0.norm1": "spatial_block.window_attn.norm",
|
472 |
+
"0.norm2": "spatial_block.ffn.norm",
|
473 |
+
"1.norm1": "channel_block.channel_attn.norm",
|
474 |
+
"1.norm2": "channel_block.ffn.norm"
|
475 |
+
}
|
476 |
+
|
477 |
+
full_key_mappings = {}
|
478 |
+
for k in pretrained_dict.keys():
|
479 |
+
old_k = k
|
480 |
+
for remap_key in remap_keys.keys():
|
481 |
+
if remap_key in k:
|
482 |
+
print(f'=> Repace {remap_key} with {remap_keys[remap_key]}')
|
483 |
+
k = k.replace(remap_key, remap_keys[remap_key])
|
484 |
+
|
485 |
+
full_key_mappings[old_k] = k
|
486 |
+
|
487 |
+
return full_key_mappings
|
488 |
+
|
489 |
+
def from_state_dict(self, pretrained_dict, pretrained_layers=[], verbose=True):
|
490 |
+
model_dict = self.state_dict()
|
491 |
+
stripped_key = lambda x: x[14:] if x.startswith('image_encoder.') else x
|
492 |
+
full_key_mappings = self._try_remap_keys(pretrained_dict)
|
493 |
+
|
494 |
+
pretrained_dict = {
|
495 |
+
stripped_key(full_key_mappings[k]): v for k, v in pretrained_dict.items()
|
496 |
+
if stripped_key(full_key_mappings[k]) in model_dict.keys()
|
497 |
+
}
|
498 |
+
need_init_state_dict = {}
|
499 |
+
for k, v in pretrained_dict.items():
|
500 |
+
need_init = (
|
501 |
+
k.split('.')[0] in pretrained_layers
|
502 |
+
or pretrained_layers[0] == '*'
|
503 |
+
)
|
504 |
+
if need_init:
|
505 |
+
if verbose:
|
506 |
+
print(f'=> init {k} from pretrained state dict')
|
507 |
+
|
508 |
+
need_init_state_dict[k] = v
|
509 |
+
self.load_state_dict(need_init_state_dict, strict=False)
|
510 |
+
|
511 |
+
def from_pretrained(self, pretrained='', pretrained_layers=[], verbose=True):
|
512 |
+
if os.path.isfile(pretrained):
|
513 |
+
print(f'=> loading pretrained model {pretrained}')
|
514 |
+
pretrained_dict = torch.load(pretrained, map_location='cpu')
|
515 |
+
|
516 |
+
self.from_state_dict(pretrained_dict, pretrained_layers, verbose)
|
517 |
+
|
518 |
+
def forward_features(self, x):
|
519 |
+
input_size = (x.size(2), x.size(3))
|
520 |
+
|
521 |
+
outs = {}
|
522 |
+
for i, (conv, block) in enumerate(zip(self.convs, self.blocks)):
|
523 |
+
x, input_size = conv(x, input_size)
|
524 |
+
if self.enable_checkpoint:
|
525 |
+
x, input_size = checkpoint.checkpoint(block, x, input_size)
|
526 |
+
else:
|
527 |
+
x, input_size = block(x, input_size)
|
528 |
+
if i in self.out_indices:
|
529 |
+
out = x.view(-1, *input_size, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous()
|
530 |
+
outs["res{}".format(i + 2)] = out
|
531 |
+
|
532 |
+
if len(self.out_indices) == 0:
|
533 |
+
outs["res5"] = x.view(-1, *input_size, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
|
534 |
+
|
535 |
+
return outs
|
536 |
+
|
537 |
+
def forward(self, x):
|
538 |
+
x = self.forward_features(x)
|
539 |
+
# x = self.head(x)
|
540 |
+
return x
|
541 |
+
|
542 |
+
class D2DaViT(DaViT, Backbone):
|
543 |
+
def __init__(self, cfg, input_shape):
|
544 |
+
|
545 |
+
spec = cfg['BACKBONE']['DAVIT']
|
546 |
+
|
547 |
+
super().__init__(
|
548 |
+
num_classes=0,
|
549 |
+
depths=spec['DEPTHS'],
|
550 |
+
embed_dims=spec['DIM_EMBED'],
|
551 |
+
num_heads=spec['NUM_HEADS'],
|
552 |
+
num_groups=spec['NUM_GROUPS'],
|
553 |
+
patch_size=spec['PATCH_SIZE'],
|
554 |
+
patch_stride=spec['PATCH_STRIDE'],
|
555 |
+
patch_padding=spec['PATCH_PADDING'],
|
556 |
+
patch_prenorm=spec['PATCH_PRENORM'],
|
557 |
+
drop_path_rate=spec['DROP_PATH_RATE'],
|
558 |
+
img_size=input_shape,
|
559 |
+
window_size=spec.get('WINDOW_SIZE', 7),
|
560 |
+
enable_checkpoint=spec.get('ENABLE_CHECKPOINT', False),
|
561 |
+
conv_at_attn=spec.get('CONV_AT_ATTN', True),
|
562 |
+
conv_at_ffn=spec.get('CONV_AT_FFN', True),
|
563 |
+
out_indices=spec.get('OUT_INDICES', []),
|
564 |
+
)
|
565 |
+
|
566 |
+
self._out_features = cfg['BACKBONE']['DAVIT']['OUT_FEATURES']
|
567 |
+
|
568 |
+
self._out_feature_strides = {
|
569 |
+
"res2": 4,
|
570 |
+
"res3": 8,
|
571 |
+
"res4": 16,
|
572 |
+
"res5": 32,
|
573 |
+
}
|
574 |
+
self._out_feature_channels = {
|
575 |
+
"res2": self.embed_dims[0],
|
576 |
+
"res3": self.embed_dims[1],
|
577 |
+
"res4": self.embed_dims[2],
|
578 |
+
"res5": self.embed_dims[3],
|
579 |
+
}
|
580 |
+
|
581 |
+
def forward(self, x):
|
582 |
+
"""
|
583 |
+
Args:
|
584 |
+
x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
|
585 |
+
Returns:
|
586 |
+
dict[str->Tensor]: names and the corresponding features
|
587 |
+
"""
|
588 |
+
assert (
|
589 |
+
x.dim() == 4
|
590 |
+
), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
|
591 |
+
outputs = {}
|
592 |
+
y = super().forward(x)
|
593 |
+
|
594 |
+
for k in y.keys():
|
595 |
+
if k in self._out_features:
|
596 |
+
outputs[k] = y[k]
|
597 |
+
return outputs
|
598 |
+
|
599 |
+
def output_shape(self):
|
600 |
+
return {
|
601 |
+
name: ShapeSpec(
|
602 |
+
channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
|
603 |
+
)
|
604 |
+
for name in self._out_features
|
605 |
+
}
|
606 |
+
|
607 |
+
@property
|
608 |
+
def size_divisibility(self):
|
609 |
+
return 32
|
610 |
+
|
611 |
+
@register_backbone
|
612 |
+
def get_davit_backbone(cfg):
|
613 |
+
davit = D2DaViT(cfg['MODEL'], 224)
|
614 |
+
|
615 |
+
if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True:
|
616 |
+
filename = cfg['MODEL']['BACKBONE']['PRETRAINED']
|
617 |
+
logger.info(f'=> init from {filename}')
|
618 |
+
davit.from_pretrained(
|
619 |
+
filename,
|
620 |
+
cfg['MODEL']['BACKBONE']['DAVIT'].get('PRETRAINED_LAYERS', ['*']),
|
621 |
+
cfg['VERBOSE'])
|
622 |
+
|
623 |
+
return davit
|
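davit.py's window attention relies on the window_partition / window_reverse helpers defined above; partitioning into non-overlapping windows and reversing is an exact round trip once H and W are multiples of the window size (WindowAttention pads to that case before partitioning). A small check, assuming the two functions are importable from glee.backbone.davit:

import torch
from glee.backbone.davit import window_partition, window_reverse  # import path assumed

B, H, W, C, ws = 2, 14, 14, 96, 7
x = torch.randn(B, H, W, C)

windows = window_partition(x, ws)              # (B * H/ws * W/ws, ws, ws, C)
assert windows.shape == (B * (H // ws) * (W // ws), ws, ws, C)

restored = window_reverse(windows, ws, H, W)   # back to (B, H, W, C)
assert torch.equal(restored, x)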
GLEE/glee/backbone/eva01.py
ADDED
@@ -0,0 +1,676 @@
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
from functools import partial
|
4 |
+
|
5 |
+
import fvcore.nn.weight_init as weight_init
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torch import Tensor, Size
|
10 |
+
from typing import Union, List
|
11 |
+
from torch.nn.parameter import Parameter
|
12 |
+
import numbers
|
13 |
+
|
14 |
+
from detectron2.layers import CNNBlockBase, Conv2d, get_norm
|
15 |
+
from detectron2.modeling.backbone.fpn import _assert_strides_are_log2_contiguous
|
16 |
+
|
17 |
+
from fairscale.nn.checkpoint import checkpoint_wrapper
|
18 |
+
from timm.models.layers import DropPath, Mlp, trunc_normal_
|
19 |
+
|
20 |
+
# from detectron2.modeling.backbone import Backbone
|
21 |
+
from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
|
22 |
+
|
23 |
+
from .eva_01_utils import (
|
24 |
+
PatchEmbed,
|
25 |
+
add_decomposed_rel_pos,
|
26 |
+
get_abs_pos,
|
27 |
+
window_partition,
|
28 |
+
window_unpartition,
|
29 |
+
)
|
30 |
+
from detectron2.modeling.backbone.fpn import LastLevelMaxPool
|
31 |
+
|
32 |
+
logger = logging.getLogger(__name__)
|
33 |
+
|
34 |
+
|
35 |
+
__all__ = ["EVAViT", "SimpleFeaturePyramid", "get_vit_lr_decay_rate"]
|
36 |
+
|
37 |
+
|
38 |
+
_shape_t = Union[int, List[int], Size]
|
39 |
+
|
40 |
+
|
41 |
+
# steal from beit https://github.com/microsoft/unilm/tree/master/beit
|
42 |
+
class LayerNormWithForceFP32(nn.Module):
|
43 |
+
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
|
44 |
+
normalized_shape: _shape_t
|
45 |
+
eps: float
|
46 |
+
elementwise_affine: bool
|
47 |
+
|
48 |
+
def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True) -> None:
|
49 |
+
super(LayerNormWithForceFP32, self).__init__()
|
50 |
+
if isinstance(normalized_shape, numbers.Integral):
|
51 |
+
normalized_shape = (normalized_shape,)
|
52 |
+
self.normalized_shape = tuple(normalized_shape)
|
53 |
+
self.eps = eps
|
54 |
+
self.elementwise_affine = elementwise_affine
|
55 |
+
if self.elementwise_affine:
|
56 |
+
self.weight = Parameter(torch.Tensor(*normalized_shape))
|
57 |
+
self.bias = Parameter(torch.Tensor(*normalized_shape))
|
58 |
+
else:
|
59 |
+
self.register_parameter('weight', None)
|
60 |
+
self.register_parameter('bias', None)
|
61 |
+
self.reset_parameters()
|
62 |
+
|
63 |
+
def reset_parameters(self) -> None:
|
64 |
+
if self.elementwise_affine:
|
65 |
+
nn.init.ones_(self.weight)
|
66 |
+
nn.init.zeros_(self.bias)
|
67 |
+
|
68 |
+
def forward(self, input: Tensor) -> Tensor:
|
69 |
+
return F.layer_norm(
|
70 |
+
input.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).type_as(input)
|
71 |
+
|
72 |
+
def extra_repr(self) -> Tensor:
|
73 |
+
return '{normalized_shape}, eps={eps}, ' \
|
74 |
+
'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
|
75 |
+
|
76 |
+
|
77 |
+
class Attention(nn.Module):
|
78 |
+
"""Multi-head Attention block with relative position embeddings."""
|
79 |
+
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
dim,
|
83 |
+
num_heads=8,
|
84 |
+
qkv_bias=True,
|
85 |
+
beit_like_qkv_bias=False,
|
86 |
+
use_rel_pos=False,
|
87 |
+
rel_pos_zero_init=True,
|
88 |
+
input_size=None,
|
89 |
+
interp_type="vitdet",
|
90 |
+
):
|
91 |
+
"""
|
92 |
+
Args:
|
93 |
+
dim (int): Number of input channels.
|
94 |
+
num_heads (int): Number of attention heads.
|
95 |
+
qkv_bias (bool: If True, add a learnable bias to query, key, value.
|
96 |
+
rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
97 |
+
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
98 |
+
input_size (int or None): Input resolution for calculating the relative positional
|
99 |
+
parameter size.
|
100 |
+
"""
|
101 |
+
super().__init__()
|
102 |
+
self.num_heads = num_heads
|
103 |
+
head_dim = dim // num_heads
|
104 |
+
self.scale = head_dim**-0.5
|
105 |
+
|
106 |
+
self.beit_like_qkv_bias = beit_like_qkv_bias
|
107 |
+
if beit_like_qkv_bias:
|
108 |
+
self.q_bias = nn.Parameter(torch.zeros(dim))
|
109 |
+
self.v_bias = nn.Parameter(torch.zeros(dim))
|
110 |
+
|
111 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
112 |
+
self.proj = nn.Linear(dim, dim)
|
113 |
+
|
114 |
+
self.use_rel_pos = use_rel_pos
|
115 |
+
self.interp_type = interp_type
|
116 |
+
if self.use_rel_pos:
|
117 |
+
# initialize relative positional embeddings
|
118 |
+
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
|
119 |
+
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
|
120 |
+
|
121 |
+
if not rel_pos_zero_init:
|
122 |
+
trunc_normal_(self.rel_pos_h, std=0.02)
|
123 |
+
trunc_normal_(self.rel_pos_w, std=0.02)
|
124 |
+
self.qk_float = False
|
125 |
+
|
126 |
+
def forward(self, x):
|
127 |
+
B, H, W, _ = x.shape
|
128 |
+
# qkv with shape (3, B, nHead, H * W, C)
|
129 |
+
if self.beit_like_qkv_bias:
|
130 |
+
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
|
131 |
+
qkv = torch.nn.functional.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
132 |
+
qkv = qkv.reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
133 |
+
else:
|
134 |
+
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
135 |
+
# q, k, v with shape (B * nHead, H * W, C)
|
136 |
+
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
|
137 |
+
|
138 |
+
if self.qk_float:
|
139 |
+
attn = (q.float() * self.scale) @ k.float().transpose(-2, -1)
|
140 |
+
if self.use_rel_pos:
|
141 |
+
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W), self.interp_type)
|
142 |
+
attn = attn.softmax(dim=-1).type_as(x)
|
143 |
+
else:
|
144 |
+
attn = (q * self.scale) @ k.transpose(-2, -1)
|
145 |
+
if self.use_rel_pos:
|
146 |
+
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W), self.interp_type)
|
147 |
+
attn = attn.softmax(dim=-1)
|
148 |
+
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
|
149 |
+
x = self.proj(x)
|
150 |
+
|
151 |
+
return x
|
152 |
+
|
153 |
+
|
154 |
+
class ResBottleneckBlock(CNNBlockBase):
|
155 |
+
"""
|
156 |
+
The standard bottleneck residual block without the last activation layer.
|
157 |
+
It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
|
158 |
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        bottleneck_channels,
        norm="LN",
        act_layer=nn.GELU,
    ):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            bottleneck_channels (int): number of output channels for the 3x3
                "bottleneck" conv layers.
            norm (str or callable): normalization for all conv layers.
                See :func:`layers.get_norm` for supported format.
            act_layer (callable): activation for all conv layers.
        """
        super().__init__(in_channels, out_channels, 1)

        self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
        self.norm1 = get_norm(norm, bottleneck_channels)
        self.act1 = act_layer()

        self.conv2 = Conv2d(
            bottleneck_channels,
            bottleneck_channels,
            3,
            padding=1,
            bias=False,
        )
        self.norm2 = get_norm(norm, bottleneck_channels)
        self.act2 = act_layer()

        self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
        self.norm3 = get_norm(norm, out_channels)

        for layer in [self.conv1, self.conv2, self.conv3]:
            weight_init.c2_msra_fill(layer)
        for layer in [self.norm1, self.norm2]:
            layer.weight.data.fill_(1.0)
            layer.bias.data.zero_()
        # zero init last norm layer.
        self.norm3.weight.data.zero_()
        self.norm3.bias.data.zero_()

    def forward(self, x):
        out = x
        for layer in self.children():
            out = layer(out)

        out = x + out
        return out


class Block(nn.Module):
    """Transformer blocks with support of window attention and residual propagation blocks"""

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop_path=0.0,
        norm_layer=LayerNormWithForceFP32,
        act_layer=nn.GELU,
        use_rel_pos=False,
        rel_pos_zero_init=True,
        window_size=0,
        use_residual_block=False,
        input_size=None,
        beit_like_qkv_bias=False,
        beit_like_gamma=False,
        interp_type="vitdet",
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path (float): Stochastic depth rate.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks. If it equals 0, then not
                use window attention.
            use_residual_block (bool): If True, use a residual block after the MLP block.
            input_size (int or None): Input resolution for calculating the relative positional
                parameter size.
            beit_like_qkv_bias (bool)
            beit_like_gamma (bool)
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=input_size if window_size == 0 else (window_size, window_size),
            beit_like_qkv_bias=beit_like_qkv_bias,
            interp_type=interp_type,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer)

        self.window_size = window_size

        self.use_residual_block = use_residual_block
        if use_residual_block:
            # Use a residual block with bottleneck channel as dim // 2
            self.residual = ResBottleneckBlock(
                in_channels=dim,
                out_channels=dim,
                bottleneck_channels=dim // 2,
                norm="LN",
                act_layer=act_layer,
            )

        self.beit_like_gamma = beit_like_gamma
        if beit_like_gamma:
            self.gamma_1 = nn.Parameter(torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(torch.ones((dim)), requires_grad=True)

    def forward(self, x):
        shortcut = x
        x = self.norm1(x)
        # Window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attn(x)
        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        if self.beit_like_gamma:
            x = shortcut + self.drop_path(self.gamma_1 * x)
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        else:
            x = shortcut + self.drop_path(x)
            x = x + self.drop_path(self.mlp(self.norm2(x)))

        if self.use_residual_block:
            x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

        return x


class EVAViT(Backbone):
    """
    This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`.
    "Exploring Plain Vision Transformer Backbones for Object Detection",
    https://arxiv.org/abs/2203.16527
    """

    def __init__(
        self,
        img_size=1024,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop_path_rate=0.0,
        norm_layer=LayerNormWithForceFP32,
        act_layer=nn.GELU,
        use_abs_pos=True,
        use_rel_pos=False,
        rel_pos_zero_init=True,
        window_size=0,
        window_block_indexes=(),
        residual_block_indexes=(),
        use_act_checkpoint=False,
        pretrain_img_size=224,
        pretrain_use_cls_token=True,
        out_feature="last_feat",
        beit_like_qkv_bias=True,
        beit_like_gamma=False,
        freeze_patch_embed=False,
        interp_type="vitdet",
    ):
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path_rate (float): Stochastic depth rate.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            window_block_indexes (list): Indexes for blocks using window attention.
            residual_block_indexes (list): Indexes for blocks using conv propagation.
            use_act_checkpoint (bool): If True, use activation checkpointing.
            pretrain_img_size (int): input image size for pretraining models.
            pretrain_use_cls_token (bool): If True, pretraining models use class token.
            out_feature (str): name of the feature from the last block.
            beit_like_qkv_bias (bool): beit_like_model that has gamma_1 and gamma_2 in blocks and qkv_bias=False
            beit_like_gamma (bool)
            freeze_patch_embed (bool)
            interp_type: "vitdet" for training / fine-tuning, "beit" for eval (slight improvement at a higher res)
        """
        super().__init__()
        self.pretrain_use_cls_token = pretrain_use_cls_token

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
            num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
            self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
        else:
            self.pos_embed = None

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        self.blocks = nn.ModuleList()
        if beit_like_qkv_bias:
            qkv_bias = False
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=window_size if i in window_block_indexes else 0,
                use_residual_block=i in residual_block_indexes,
                input_size=(img_size // patch_size, img_size // patch_size),
                beit_like_qkv_bias=beit_like_qkv_bias,
                beit_like_gamma=beit_like_gamma,
                interp_type=interp_type,
            )
            if use_act_checkpoint:
                block = checkpoint_wrapper(block)
            self.blocks.append(block)

        self._out_feature_channels = {out_feature: embed_dim}
        self._out_feature_strides = {out_feature: patch_size}
        self._out_features = [out_feature]

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=0.02)

        self.freeze_patch_embed = freeze_patch_embed
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, LayerNormWithForceFP32):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

        if self.freeze_patch_embed:
            for n, p in self.patch_embed.named_parameters():
                p.requires_grad = False

    def forward(self, x):
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + get_abs_pos(
                self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
            )

        for blk in self.blocks:
            x = blk(x)

        outputs = {self._out_features[0]: x.permute(0, 3, 1, 2)}
        return outputs
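
# Editor's note: a minimal shape sketch for the forward pass above (illustrative only,
# not part of the committed file; the import path is an assumption).
#
#     import torch
#     from glee.backbone.eva01 import EVAViT  # assumed import path
#
#     vit = EVAViT(img_size=1024, patch_size=16, embed_dim=768, depth=12, num_heads=12)
#     x = torch.randn(1, 3, 1024, 1024)
#     feat = vit(x)["last_feat"]
#     # patch_embed yields (1, 64, 64, 768) in channels-last layout; the blocks keep
#     # that shape, and forward() permutes it to (1, 768, 64, 64) before returning.
#     assert feat.shape == (1, 768, 64, 64)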


class SimpleFeaturePyramid(Backbone):
    """
    This module implements SimpleFeaturePyramid in :paper:`vitdet`.
    It creates pyramid features built on top of the input feature map.
    """

    def __init__(
        self,
        net,
        in_feature,
        out_channels,
        scale_factors,
        top_block=None,
        norm="LN",
        square_pad=0,
    ):
        """
        Args:
            net (Backbone): module representing the subnetwork backbone.
                Must be a subclass of :class:`Backbone`.
            in_feature (str): names of the input feature maps coming
                from the net.
            out_channels (int): number of channels in the output feature maps.
            scale_factors (list[float]): list of scaling factors to upsample or downsample
                the input features for creating pyramid features.
            top_block (nn.Module or None): if provided, an extra operation will
                be performed on the output of the last (smallest resolution)
                pyramid output, and the result will extend the result list. The top_block
                further downsamples the feature map. It must have an attribute
                "num_levels", meaning the number of extra pyramid levels added by
                this block, and "in_feature", which is a string representing
                its input feature (e.g., p5).
            norm (str): the normalization to use.
            square_pad (int): If > 0, require input images to be padded to specific square size.
        """
        super(SimpleFeaturePyramid, self).__init__()
        assert isinstance(net, Backbone)
        self.scale_factors = scale_factors

        input_shapes = net.output_shape()
        strides = [int(input_shapes[in_feature].stride / scale) for scale in scale_factors]
        _assert_strides_are_log2_contiguous(strides)

        dim = input_shapes[in_feature].channels
        self.stages = []
        use_bias = norm == ""
        for idx, scale in enumerate(scale_factors):
            out_dim = dim
            if scale == 4.0:
                layers = [
                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
                    get_norm(norm, dim // 2),
                    nn.GELU(),
                    nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
                ]
                out_dim = dim // 4
            elif scale == 2.0:
                layers = [nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)]
                out_dim = dim // 2
            elif scale == 1.0:
                layers = []
            elif scale == 0.5:
                layers = [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                raise NotImplementedError(f"scale_factor={scale} is not supported yet.")

            layers.extend(
                [
                    Conv2d(
                        out_dim,
                        out_channels,
                        kernel_size=1,
                        bias=use_bias,
                        norm=get_norm(norm, out_channels),
                    ),
                    Conv2d(
                        out_channels,
                        out_channels,
                        kernel_size=3,
                        padding=1,
                        bias=use_bias,
                        norm=get_norm(norm, out_channels),
                    ),
                ]
            )
            layers = nn.Sequential(*layers)

            stage = int(math.log2(strides[idx]))
            self.add_module(f"simfp_{stage}", layers)
            self.stages.append(layers)

        self.net = net
        self.in_feature = in_feature
        self.top_block = top_block
        # Return feature names are "p<stage>", like ["p2", "p3", ..., "p6"]
        self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides}
        # top block output feature maps.
        if self.top_block is not None:
            for s in range(stage, stage + self.top_block.num_levels):
                self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1)

        self._out_features = list(self._out_feature_strides.keys())
        self._out_feature_channels = {k: out_channels for k in self._out_features}
        self._size_divisibility = strides[-1]
        self._square_pad = square_pad

    @property
    def padding_constraints(self):
        return {
            "size_divisiblity": self._size_divisibility,
            "square_size": self._square_pad,
        }

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.

        Returns:
            dict[str->Tensor]:
                mapping from feature map name to pyramid feature map tensor
                in high to low resolution order. Returned feature names follow the FPN
                convention: "p<stage>", where stage has stride = 2 ** stage e.g.,
                ["p2", "p3", ..., "p6"].
        """
        bottom_up_features = self.net(x)
        features = bottom_up_features[self.in_feature]
        results = []

        for stage in self.stages:
            results.append(stage(features))

        if self.top_block is not None:
            if self.top_block.in_feature in bottom_up_features:
                top_block_in_feature = bottom_up_features[self.top_block.in_feature]
            else:
                top_block_in_feature = results[self._out_features.index(self.top_block.in_feature)]
            results.extend(self.top_block(top_block_in_feature))
        assert len(self._out_features) == len(results)
        return {f: res for f, res in zip(self._out_features, results)}
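
# Editor's note: a worked example of the stride bookkeeping above (numbers mirror how
# D2_EVA01 below instantiates the pyramid; this snippet is illustrative, not committed code).
#
#     strides = [int(16 / s) for s in (2.0, 1.0, 0.5)]            # [8, 16, 32]
#     names = ["p{}".format(int(math.log2(s))) for s in strides]  # ['p3', 'p4', 'p5']
#     # a LastLevelMaxPool top_block then appends 'p6' at stride 64, so forward()
#     # returns {'p3': ..., 'p4': ..., 'p5': ..., 'p6': ...}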



@BACKBONE_REGISTRY.register()
class D2_EVA01(SimpleFeaturePyramid):
    def __init__(self, cfg, input_shape):

        super().__init__(
            net = EVAViT(
                img_size= cfg.MODEL.EVA01.IMAGE_SIZE,
                patch_size=cfg.MODEL.EVA01.PATCH_SIZE,
                window_size= cfg.MODEL.EVA01.WINDOW_SIZE,
                embed_dim= cfg.MODEL.EVA01.DMBED_DIM,
                depth= cfg.MODEL.EVA01.DEPTH,
                num_heads= cfg.MODEL.EVA01.NUM_HEADS,
                drop_path_rate= cfg.MODEL.EVA01.DROP_PATH_RATE,
                mlp_ratio= cfg.MODEL.EVA01.MLP_RATIO,
                qkv_bias=True,
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
                window_block_indexes= cfg.MODEL.EVA01.WINDOW_BLOCK_INDEXES,
                residual_block_indexes=[],
                use_act_checkpoint = True,
                use_rel_pos = True,
                out_feature="last_feat",
                beit_like_qkv_bias=cfg.MODEL.EVA01.BEIT_LIKE_QKV_BIAS,
                beit_like_gamma= cfg.MODEL.EVA01.BEIT_LIKE_GAMMA,
                freeze_patch_embed= cfg.MODEL.EVA01.FREEZE_PATH_EMBED,
            ),
            in_feature = "last_feat",
            out_channels=256,
            scale_factors=(2.0, 1.0, 0.5),  # (4.0, 2.0, 1.0, 0.5) in ViTDet
            top_block=LastLevelMaxPool(),
            norm="LN",
            square_pad=cfg.MODEL.EVA01.IMAGE_SIZE,
        )
        pretrained_weight = cfg.MODEL.EVA01.PRETRAINED_WEIGHT
        if pretrained_weight:
            checkpoint = torch.load(pretrained_weight, map_location='cpu')
            print(f'\nload pretrain weight from {pretrained_weight} \n')
            self.load_state_dict(checkpoint['model'], strict=False)

    def output_shape(self):
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
            )
            for name in self._out_features
        }

    @property
    def size_divisibility(self):
        return 32


def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
    """
    Calculate lr decay rate for different ViT blocks.
    Args:
        name (string): parameter name.
        lr_decay_rate (float): base lr decay rate.
        num_layers (int): number of ViT blocks.

    Returns:
        lr decay rate for the given parameter.
    """
    layer_id = num_layers + 1
    if 'backbone' in name:  # name.startswith("backbone"):
        if ".pos_embed" in name or ".patch_embed" in name:
            layer_id = 0
        elif ".blocks." in name and ".residual." not in name:
            layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1

    return lr_decay_rate ** (num_layers + 1 - layer_id)
GLEE/glee/backbone/eva02-dino.py
ADDED
@@ -0,0 +1,598 @@
import logging
import math
from functools import partial

import fvcore.nn.weight_init as weight_init
import torch
import torch.nn as nn
import torch.nn.functional as F

from detectron2.layers import CNNBlockBase, Conv2d, get_norm
from detectron2.modeling.backbone.fpn import _assert_strides_are_log2_contiguous

from detectron2.modeling.backbone import Backbone
from .eva_02_utils import (
    PatchEmbed,
    add_decomposed_rel_pos,
    get_abs_pos,
    window_partition,
    window_unpartition,
    VisionRotaryEmbeddingFast,
)

try:
    import xformers.ops as xops
    HAS_XFORMER = True
except:
    HAS_XFORMER = False
    pass


logger = logging.getLogger(__name__)


__all__ = ["EVA02_ViT", "SimpleFeaturePyramid", "get_vit_lr_decay_rate"]


class SwiGLU(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
                 norm_layer=nn.LayerNorm, subln=False
                 ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.w1 = nn.Linear(in_features, hidden_features)
        self.w2 = nn.Linear(in_features, hidden_features)

        self.act = act_layer()
        self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
        self.w3 = nn.Linear(hidden_features, out_features)

        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x1 = self.w1(x)
        x2 = self.w2(x)
        hidden = self.act(x1) * x2
        x = self.ffn_ln(hidden)
        x = self.w3(x)
        x = self.drop(x)
        return x
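
# Editor's note: the feed-forward above computes w3(ffn_ln(SiLU(w1(x)) * w2(x))), i.e. a
# gated MLP with an optional sub-LayerNorm. A minimal shape check (illustrative only):
#
#     ffn = SwiGLU(in_features=768, hidden_features=512, subln=True)
#     y = ffn(torch.randn(2, 10, 768))
#     assert y.shape == (2, 10, 768)  # the output keeps the input feature width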


class Attention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=True,
        qk_scale=None,
        attn_head_dim=None,
        rope=None,
        xattn=True,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
        self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
        self.v_proj = nn.Linear(dim, all_head_dim, bias=False)

        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        self.rope = rope
        self.xattn = xattn
        self.proj = nn.Linear(all_head_dim, dim)

        if not HAS_XFORMER:
            self.xattn = False

    def forward(self, x):
        B, H, W, C = x.shape
        x = x.view(B, -1, C)
        N = H * W

        q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
        k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
        v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)

        q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)  # B, num_heads, N, C
        k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
        v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)

        ## rope
        q = self.rope(q).type_as(v)
        k = self.rope(k).type_as(v)

        if self.xattn:
            q = q.permute(0, 2, 1, 3)  # B, num_heads, N, C -> B, N, num_heads, C
            k = k.permute(0, 2, 1, 3)
            v = v.permute(0, 2, 1, 3)

            x = xops.memory_efficient_attention(q, k, v)
            x = x.reshape(B, N, -1)
        else:
            q = q * self.scale
            attn = (q @ k.transpose(-2, -1))
            attn = attn.softmax(dim=-1).type_as(x)
            x = (attn @ v).transpose(1, 2).reshape(B, N, -1)

        x = self.proj(x)
        x = x.view(B, H, W, C)

        return x


class ResBottleneckBlock(CNNBlockBase):
    """
    The standard bottleneck residual block without the last activation layer.
    It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        bottleneck_channels,
        norm="LN",
        act_layer=nn.GELU,
    ):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            bottleneck_channels (int): number of output channels for the 3x3
                "bottleneck" conv layers.
            norm (str or callable): normalization for all conv layers.
                See :func:`layers.get_norm` for supported format.
            act_layer (callable): activation for all conv layers.
        """
        super().__init__(in_channels, out_channels, 1)

        self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
        self.norm1 = get_norm(norm, bottleneck_channels)
        self.act1 = act_layer()

        self.conv2 = Conv2d(
            bottleneck_channels,
            bottleneck_channels,
            3,
            padding=1,
            bias=False,
        )
        self.norm2 = get_norm(norm, bottleneck_channels)
        self.act2 = act_layer()

        self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
        self.norm3 = get_norm(norm, out_channels)

        for layer in [self.conv1, self.conv2, self.conv3]:
            weight_init.c2_msra_fill(layer)
        for layer in [self.norm1, self.norm2]:
            layer.weight.data.fill_(1.0)
            layer.bias.data.zero_()
        # zero init last norm layer.
        self.norm3.weight.data.zero_()
        self.norm3.bias.data.zero_()

    def forward(self, x):
        out = x
        for layer in self.children():
            out = layer(out)

        out = x + out
        return out


class Block(nn.Module):
    """Transformer blocks with support of window attention and residual propagation blocks"""

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4*2/3,
        qkv_bias=True,
        drop_path=0.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        window_size=0,
        use_residual_block=False,
        rope=None,
        xattn=True,
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path (float): Stochastic depth rate.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks. If it equals 0, then not
                use window attention.
            use_residual_block (bool): If True, use a residual block after the MLP block.
            input_size (int or None): Input resolution for calculating the relative positional
                parameter size.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            rope=rope,
            xattn=xattn,
        )

        from timm.models.layers import DropPath

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = SwiGLU(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            subln=True,
            norm_layer=norm_layer,
        )

        self.window_size = window_size

        self.use_residual_block = use_residual_block
        if use_residual_block:
            # Use a residual block with bottleneck channel as dim // 2
            self.residual = ResBottleneckBlock(
                in_channels=dim,
                out_channels=dim,
                bottleneck_channels=dim // 2,
                norm="LN",
            )

    def forward(self, x):
        shortcut = x
        x = self.norm1(x)

        # Window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attn(x)

        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        if self.use_residual_block:
            x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

        return x


class EVA02_ViT(Backbone):
    """
    This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`.
    "Exploring Plain Vision Transformer Backbones for Object Detection",
    https://arxiv.org/abs/2203.16527
    """

    def __init__(
        self,
        img_size=1024,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4*2/3,
        qkv_bias=True,
        drop_path_rate=0.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        act_layer=nn.GELU,
        use_abs_pos=True,
        use_rel_pos=False,
        rope=True,
        pt_hw_seq_len=16,
        intp_freq=True,
        window_size=0,
        window_block_indexes=(),
        residual_block_indexes=(),
        use_act_checkpoint=False,
        pretrain_img_size=224,
        pretrain_use_cls_token=True,
        out_feature="last_feat",
        xattn=True,
    ):
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path_rate (float): Stochastic depth rate.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            window_block_indexes (list): Indexes for blocks using window attention.
            residual_block_indexes (list): Indexes for blocks using conv propagation.
            use_act_checkpoint (bool): If True, use activation checkpointing.
            pretrain_img_size (int): input image size for pretraining models.
            pretrain_use_cls_token (bool): If True, pretraining models use class token.
            out_feature (str): name of the feature from the last block.
        """
        super().__init__()
        self.pretrain_use_cls_token = pretrain_use_cls_token

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
            num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
            self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
        else:
            self.pos_embed = None

        half_head_dim = embed_dim // num_heads // 2
        hw_seq_len = img_size // patch_size

        self.rope_win = VisionRotaryEmbeddingFast(
            dim=half_head_dim,
            pt_seq_len=pt_hw_seq_len,
            ft_seq_len=window_size if intp_freq else None,
        )
        self.rope_glb = VisionRotaryEmbeddingFast(
            dim=half_head_dim,
            pt_seq_len=pt_hw_seq_len,
            ft_seq_len=hw_seq_len if intp_freq else None,
        )

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                window_size=window_size if i in window_block_indexes else 0,
                use_residual_block=i in residual_block_indexes,
                rope=self.rope_win if i in window_block_indexes else self.rope_glb,
                xattn=xattn
            )
            if use_act_checkpoint:
                # TODO: use torch.utils.checkpoint
                from fairscale.nn.checkpoint import checkpoint_wrapper

                block = checkpoint_wrapper(block)
            self.blocks.append(block)

        self._out_feature_channels = {out_feature: embed_dim}
        self._out_feature_strides = {out_feature: patch_size}
        self._out_features = [out_feature]

        if self.pos_embed is not None:
            nn.init.trunc_normal_(self.pos_embed, std=0.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + get_abs_pos(
                self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
            )

        for blk in self.blocks:
            x = blk(x)

        outputs = {self._out_features[0]: x.permute(0, 3, 1, 2)}
        return outputs


class SimpleFeaturePyramid(Backbone):
    """
    This module implements SimpleFeaturePyramid in :paper:`vitdet`.
    It creates pyramid features built on top of the input feature map.
    """

    def __init__(
        self,
        net,
        in_feature,
        out_channels,
        scale_factors,
        top_block=None,
        norm="LN",
        square_pad=0,
    ):
        """
        Args:
            net (Backbone): module representing the subnetwork backbone.
                Must be a subclass of :class:`Backbone`.
            in_feature (str): names of the input feature maps coming
                from the net.
            out_channels (int): number of channels in the output feature maps.
            scale_factors (list[float]): list of scaling factors to upsample or downsample
                the input features for creating pyramid features.
            top_block (nn.Module or None): if provided, an extra operation will
                be performed on the output of the last (smallest resolution)
                pyramid output, and the result will extend the result list. The top_block
                further downsamples the feature map. It must have an attribute
                "num_levels", meaning the number of extra pyramid levels added by
                this block, and "in_feature", which is a string representing
                its input feature (e.g., p5).
            norm (str): the normalization to use.
            square_pad (int): If > 0, require input images to be padded to specific square size.
        """
        super(SimpleFeaturePyramid, self).__init__()
        assert isinstance(net, Backbone)

        self.scale_factors = scale_factors

        input_shapes = net.output_shape()
        strides = [int(input_shapes[in_feature].stride / scale) for scale in scale_factors]
        _assert_strides_are_log2_contiguous(strides)

        dim = input_shapes[in_feature].channels
        self.stages = []
        use_bias = norm == ""
        for idx, scale in enumerate(scale_factors):
            out_dim = dim
            if scale == 4.0:
                layers = [
                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
                    get_norm(norm, dim // 2),
                    nn.GELU(),
                    nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
                ]
                out_dim = dim // 4
            elif scale == 2.0:
                layers = [nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)]
                out_dim = dim // 2
            elif scale == 1.0:
                layers = []
            elif scale == 0.5:
                layers = [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                raise NotImplementedError(f"scale_factor={scale} is not supported yet.")

            layers.extend(
                [
                    Conv2d(
                        out_dim,
                        out_channels,
                        kernel_size=1,
                        bias=use_bias,
                        norm=get_norm(norm, out_channels),
                    ),
                    Conv2d(
                        out_channels,
                        out_channels,
                        kernel_size=3,
                        padding=1,
                        bias=use_bias,
                        norm=get_norm(norm, out_channels),
                    ),
                ]
            )
            layers = nn.Sequential(*layers)

            stage = int(math.log2(strides[idx]))
            self.add_module(f"simfp_{stage}", layers)
            self.stages.append(layers)

        self.net = net
        self.in_feature = in_feature
        self.top_block = top_block
        # Return feature names are "p<stage>", like ["p2", "p3", ..., "p6"]
        self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides}
        # top block output feature maps.
        if self.top_block is not None:
            for s in range(stage, stage + self.top_block.num_levels):
                self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1)

        self._out_features = list(self._out_feature_strides.keys())
        self._out_feature_channels = {k: out_channels for k in self._out_features}
        self._size_divisibility = strides[-1]
        self._square_pad = square_pad

    @property
    def padding_constraints(self):
        return {
            "size_divisiblity": self._size_divisibility,
            "square_size": self._square_pad,
        }

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.

        Returns:
            dict[str->Tensor]:
                mapping from feature map name to pyramid feature map tensor
                in high to low resolution order. Returned feature names follow the FPN
                convention: "p<stage>", where stage has stride = 2 ** stage e.g.,
                ["p2", "p3", ..., "p6"].
        """
        bottom_up_features = self.net(x)
        features = bottom_up_features[self.in_feature]
        results = []

        for stage in self.stages:
            results.append(stage(features))

        if self.top_block is not None:
            if self.top_block.in_feature in bottom_up_features:
                top_block_in_feature = bottom_up_features[self.top_block.in_feature]
            else:
                top_block_in_feature = results[self._out_features.index(self.top_block.in_feature)]
            results.extend(self.top_block(top_block_in_feature))
        assert len(self._out_features) == len(results)
        return {f: res for f, res in zip(self._out_features, results)}


def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
    """
    Calculate lr decay rate for different ViT blocks.
    Args:
        name (string): parameter name.
        lr_decay_rate (float): base lr decay rate.
        num_layers (int): number of ViT blocks.

    Returns:
        lr decay rate for the given parameter.
    """
    layer_id = num_layers + 1
    if name.startswith("backbone"):
        if ".pos_embed" in name or ".patch_embed" in name:
            layer_id = 0
        elif ".blocks." in name and ".residual." not in name:
            layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1

    return lr_decay_rate ** (num_layers + 1 - layer_id)
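# Editor's note: a minimal wiring sketch for the classes in this file (illustrative only,
# not code from the commit; the window/global block split and sizes follow the usual
# ViT-Base ViTDet recipe, and xformers is optional — without it the plain softmax
# attention path in Attention above is used).
#
#     from detectron2.modeling.backbone.fpn import LastLevelMaxPool
#
#     vit = EVA02_ViT(
#         img_size=1024, patch_size=16, embed_dim=768, depth=12, num_heads=12,
#         window_size=16, window_block_indexes=[0, 1, 3, 4, 6, 7, 9, 10],
#     )
#     backbone = SimpleFeaturePyramid(
#         net=vit, in_feature="last_feat", out_channels=256,
#         scale_factors=(4.0, 2.0, 1.0, 0.5), top_block=LastLevelMaxPool(), norm="LN",
#     )
#     feats = backbone(torch.randn(1, 3, 1024, 1024))
#     # feats maps 'p2'..'p6' to 256-channel maps with strides 4, 8, 16, 32, 64.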
GLEE/glee/backbone/eva02.py
ADDED
@@ -0,0 +1,647 @@
# --------------------------------------------------------
# EVA02
# --------------------------------------------------------
import logging
import math
from functools import partial

import fvcore.nn.weight_init as weight_init
import torch
import torch.nn as nn
import torch.nn.functional as F

from detectron2.layers import CNNBlockBase, Conv2d, get_norm
from detectron2.modeling.backbone.fpn import _assert_strides_are_log2_contiguous

from detectron2.modeling.backbone import Backbone
from timm.models.layers import DropPath, Mlp, trunc_normal_
from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec


from .eva_02_utils import (
    PatchEmbed,
    add_decomposed_rel_pos,
    get_abs_pos,
    window_partition,
    window_unpartition,
    VisionRotaryEmbeddingFast,
)
from detectron2.modeling.backbone.fpn import LastLevelMaxPool


try:
    import xformers.ops as xops
    HAS_XFORMER = True
except:
    HAS_XFORMER = False
    pass

try:
    from apex.normalization import FusedLayerNorm
except:
    pass

logger = logging.getLogger(__name__)


__all__ = ["EVA02_ViT", "SimpleFeaturePyramid", "get_vit_lr_decay_rate"]


class SwiGLU(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
                 norm_layer=nn.LayerNorm, subln=False
                 ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.w1 = nn.Linear(in_features, hidden_features)
        self.w2 = nn.Linear(in_features, hidden_features)

        self.act = act_layer()
        self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
        self.w3 = nn.Linear(hidden_features, out_features)

        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x1 = self.w1(x)
        x2 = self.w2(x)
        hidden = self.act(x1) * x2
        x = self.ffn_ln(hidden)
        x = self.w3(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=True,
        qk_scale=None,
        attn_head_dim=None,
        rope=None,
        xattn=True,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
        self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
        self.v_proj = nn.Linear(dim, all_head_dim, bias=False)

        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        self.rope = rope
        self.xattn = xattn
        self.proj = nn.Linear(all_head_dim, dim)
        if not HAS_XFORMER:
            self.xattn = False

    def forward(self, x):
        B, H, W, C = x.shape
        x = x.view(B, -1, C)
        N = H * W

        q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
        k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
        v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)

        q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)  # B, num_heads, N, C
        k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
        v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)

        ## rope
        q = self.rope(q).type_as(v)
        k = self.rope(k).type_as(v)

        if self.xattn:
            q = q.permute(0, 2, 1, 3)  # B, num_heads, N, C -> B, N, num_heads, C
            k = k.permute(0, 2, 1, 3)
            v = v.permute(0, 2, 1, 3)

            x = xops.memory_efficient_attention(q, k, v)
            x = x.reshape(B, N, -1)
        else:
            q = q * self.scale
            attn = (q @ k.transpose(-2, -1))
            attn = attn.softmax(dim=-1).type_as(x)
            x = (attn @ v).transpose(1, 2).reshape(B, N, -1)

        x = self.proj(x)
        x = x.view(B, H, W, C)

        return x


class ResBottleneckBlock(CNNBlockBase):
    """
    The standard bottleneck residual block without the last activation layer.
    It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        bottleneck_channels,
        norm="LN",
        act_layer=nn.GELU,
    ):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            bottleneck_channels (int): number of output channels for the 3x3
                "bottleneck" conv layers.
            norm (str or callable): normalization for all conv layers.
                See :func:`layers.get_norm` for supported format.
            act_layer (callable): activation for all conv layers.
        """
        super().__init__(in_channels, out_channels, 1)

        self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
        self.norm1 = get_norm(norm, bottleneck_channels)
        self.act1 = act_layer()

        self.conv2 = Conv2d(
            bottleneck_channels,
            bottleneck_channels,
            3,
            padding=1,
            bias=False,
        )
        self.norm2 = get_norm(norm, bottleneck_channels)
        self.act2 = act_layer()

        self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
        self.norm3 = get_norm(norm, out_channels)

        for layer in [self.conv1, self.conv2, self.conv3]:
            weight_init.c2_msra_fill(layer)
        for layer in [self.norm1, self.norm2]:
            layer.weight.data.fill_(1.0)
            layer.bias.data.zero_()
        # zero init last norm layer.
        self.norm3.weight.data.zero_()
        self.norm3.bias.data.zero_()

    def forward(self, x):
        out = x
        for layer in self.children():
            out = layer(out)

        out = x + out
        return out


class Block(nn.Module):
    """Transformer blocks with support of window attention and residual propagation blocks"""

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4*2/3,
        qkv_bias=True,
        drop_path=0.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        window_size=0,
        use_residual_block=False,
        rope=None,
        xattn=True,
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path (float): Stochastic depth rate.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks. If it equals 0, then not
                use window attention.
            use_residual_block (bool): If True, use a residual block after the MLP block.
            input_size (int or None): Input resolution for calculating the relative positional
                parameter size.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            rope=rope,
            xattn=xattn,
        )

        from timm.models.layers import DropPath

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = SwiGLU(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            subln=True,
            norm_layer=norm_layer,
        )

        self.window_size = window_size

        self.use_residual_block = use_residual_block
        if use_residual_block:
            # Use a residual block with bottleneck channel as dim // 2
            self.residual = ResBottleneckBlock(
                in_channels=dim,
                out_channels=dim,
                bottleneck_channels=dim // 2,
                norm="LN",
            )

    def forward(self, x):
        shortcut = x
        x = self.norm1(x)

        # Window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attn(x)

        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        if self.use_residual_block:
            x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

        return x


class EVA02_ViT(Backbone):
    """
    This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`.
    "Exploring Plain Vision Transformer Backbones for Object Detection",
    https://arxiv.org/abs/2203.16527
    """

    def __init__(
        self,
        img_size=1024,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4*2/3,
        qkv_bias=True,
        drop_path_rate=0.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        act_layer=nn.GELU,
        use_abs_pos=True,
        use_rel_pos=False,
        rope=True,
        pt_hw_seq_len=16,
        intp_freq=True,
        window_size=0,
        window_block_indexes=(),
        residual_block_indexes=(),
        use_act_checkpoint=False,
        pretrain_img_size=224,
        pretrain_use_cls_token=True,
        out_feature="last_feat",
        xattn=True,
    ):
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path_rate (float): Stochastic depth rate.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            window_block_indexes (list): Indexes for blocks using window attention.
            residual_block_indexes (list): Indexes for blocks using conv propagation.
            use_act_checkpoint (bool): If True, use activation checkpointing.
            pretrain_img_size (int): input image size for pretraining models.
            pretrain_use_cls_token (bool): If True, pretraining models use class token.
            out_feature (str): name of the feature from the last block.
        """
        super().__init__()
        self.pretrain_use_cls_token = pretrain_use_cls_token

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
            num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
            self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
        else:
            self.pos_embed = None

        half_head_dim = embed_dim // num_heads // 2
        hw_seq_len = img_size // patch_size

        self.rope_win = VisionRotaryEmbeddingFast(
            dim=half_head_dim,
            pt_seq_len=pt_hw_seq_len,
            ft_seq_len=window_size if intp_freq else None,
        )
        self.rope_glb = VisionRotaryEmbeddingFast(
            dim=half_head_dim,
            pt_seq_len=pt_hw_seq_len,
            ft_seq_len=hw_seq_len if intp_freq else None,
        )

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                window_size=window_size if i in window_block_indexes else 0,
                use_residual_block=i in residual_block_indexes,
                rope=self.rope_win if i in window_block_indexes else self.rope_glb,
                xattn=xattn
            )
            if use_act_checkpoint:
                # TODO: use torch.utils.checkpoint
                from fairscale.nn.checkpoint import checkpoint_wrapper

                block = checkpoint_wrapper(block)
            self.blocks.append(block)

        self._out_feature_channels = {out_feature: embed_dim}
        self._out_feature_strides = {out_feature: patch_size}
        self._out_features = [out_feature]

        if self.pos_embed is not None:
            nn.init.trunc_normal_(self.pos_embed, std=0.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + get_abs_pos(
                self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
            )

        for blk in self.blocks:
            x = blk(x)

        outputs = {self._out_features[0]: x.permute(0, 3, 1, 2)}
        return outputs


class SimpleFeaturePyramid(Backbone):
    """
    This module implements SimpleFeaturePyramid in :paper:`vitdet`.
    It creates pyramid features built on top of the input feature map.
    """

    def __init__(
        self,
        net,
        in_feature,
        out_channels,
        scale_factors,
        top_block=None,
        norm="LN",
        square_pad=0,
    ):
        """
        Args:
            net (Backbone): module representing the subnetwork backbone.
                Must be a subclass of :class:`Backbone`.
            in_feature (str): names of the input feature maps coming
                from the net.
            out_channels (int): number of channels in the output feature maps.
            scale_factors (list[float]): list of scaling factors to upsample or downsample
                the input features for creating pyramid features.
            top_block (nn.Module or None): if provided, an extra operation will
                be performed on the output of the last (smallest resolution)
be performed on the output of the last (smallest resolution)
|
476 |
+
pyramid output, and the result will extend the result list. The top_block
|
477 |
+
further downsamples the feature map. It must have an attribute
|
478 |
+
"num_levels", meaning the number of extra pyramid levels added by
|
479 |
+
this block, and "in_feature", which is a string representing
|
480 |
+
its input feature (e.g., p5).
|
481 |
+
norm (str): the normalization to use.
|
482 |
+
square_pad (int): If > 0, require input images to be padded to specific square size.
|
483 |
+
"""
|
484 |
+
super(SimpleFeaturePyramid, self).__init__()
|
485 |
+
assert isinstance(net, Backbone)
|
486 |
+
|
487 |
+
self.scale_factors = scale_factors
|
488 |
+
|
489 |
+
input_shapes = net.output_shape()
|
490 |
+
strides = [int(input_shapes[in_feature].stride / scale) for scale in scale_factors]
|
491 |
+
_assert_strides_are_log2_contiguous(strides)
|
492 |
+
|
493 |
+
dim = input_shapes[in_feature].channels
|
494 |
+
self.stages = []
|
495 |
+
use_bias = norm == ""
|
496 |
+
for idx, scale in enumerate(scale_factors):
|
497 |
+
out_dim = dim
|
498 |
+
if scale == 4.0:
|
499 |
+
layers = [
|
500 |
+
nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
|
501 |
+
get_norm(norm, dim // 2),
|
502 |
+
nn.GELU(),
|
503 |
+
nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
|
504 |
+
]
|
505 |
+
out_dim = dim // 4
|
506 |
+
elif scale == 2.0:
|
507 |
+
layers = [nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)]
|
508 |
+
out_dim = dim // 2
|
509 |
+
elif scale == 1.0:
|
510 |
+
layers = []
|
511 |
+
elif scale == 0.5:
|
512 |
+
layers = [nn.MaxPool2d(kernel_size=2, stride=2)]
|
513 |
+
else:
|
514 |
+
raise NotImplementedError(f"scale_factor={scale} is not supported yet.")
|
515 |
+
|
516 |
+
layers.extend(
|
517 |
+
[
|
518 |
+
Conv2d(
|
519 |
+
out_dim,
|
520 |
+
out_channels,
|
521 |
+
kernel_size=1,
|
522 |
+
bias=use_bias,
|
523 |
+
norm=get_norm(norm, out_channels),
|
524 |
+
),
|
525 |
+
Conv2d(
|
526 |
+
out_channels,
|
527 |
+
out_channels,
|
528 |
+
kernel_size=3,
|
529 |
+
padding=1,
|
530 |
+
bias=use_bias,
|
531 |
+
norm=get_norm(norm, out_channels),
|
532 |
+
),
|
533 |
+
]
|
534 |
+
)
|
535 |
+
layers = nn.Sequential(*layers)
|
536 |
+
|
537 |
+
stage = int(math.log2(strides[idx]))
|
538 |
+
self.add_module(f"simfp_{stage}", layers)
|
539 |
+
self.stages.append(layers)
|
540 |
+
|
541 |
+
self.net = net
|
542 |
+
self.in_feature = in_feature
|
543 |
+
self.top_block = top_block
|
544 |
+
# Return feature names are "p<stage>", like ["p2", "p3", ..., "p6"]
|
545 |
+
self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides}
|
546 |
+
# top block output feature maps.
|
547 |
+
if self.top_block is not None:
|
548 |
+
for s in range(stage, stage + self.top_block.num_levels):
|
549 |
+
self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1)
|
550 |
+
|
551 |
+
self._out_features = list(self._out_feature_strides.keys())
|
552 |
+
self._out_feature_channels = {k: out_channels for k in self._out_features}
|
553 |
+
self._size_divisibility = strides[-1]
|
554 |
+
self._square_pad = square_pad
|
555 |
+
|
556 |
+
@property
|
557 |
+
def padding_constraints(self):
|
558 |
+
return {
|
559 |
+
"size_divisiblity": self._size_divisibility,
|
560 |
+
"square_size": self._square_pad,
|
561 |
+
}
|
562 |
+
|
563 |
+
def forward(self, x):
|
564 |
+
"""
|
565 |
+
Args:
|
566 |
+
x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
|
567 |
+
|
568 |
+
Returns:
|
569 |
+
dict[str->Tensor]:
|
570 |
+
mapping from feature map name to pyramid feature map tensor
|
571 |
+
in high to low resolution order. Returned feature names follow the FPN
|
572 |
+
convention: "p<stage>", where stage has stride = 2 ** stage e.g.,
|
573 |
+
["p2", "p3", ..., "p6"].
|
574 |
+
"""
|
575 |
+
bottom_up_features = self.net(x)
|
576 |
+
features = bottom_up_features[self.in_feature]
|
577 |
+
results = []
|
578 |
+
|
579 |
+
for stage in self.stages:
|
580 |
+
results.append(stage(features))
|
581 |
+
|
582 |
+
if self.top_block is not None:
|
583 |
+
if self.top_block.in_feature in bottom_up_features:
|
584 |
+
top_block_in_feature = bottom_up_features[self.top_block.in_feature]
|
585 |
+
else:
|
586 |
+
top_block_in_feature = results[self._out_features.index(self.top_block.in_feature)]
|
587 |
+
results.extend(self.top_block(top_block_in_feature))
|
588 |
+
assert len(self._out_features) == len(results)
|
589 |
+
return {f: res for f, res in zip(self._out_features, results)}
|
590 |
+
|
591 |
+
|
592 |
+
|
593 |
+
@BACKBONE_REGISTRY.register()
|
594 |
+
class D2_EVA02(SimpleFeaturePyramid):
|
595 |
+
def __init__(self, cfg, input_shape):
|
596 |
+
|
597 |
+
super().__init__(
|
598 |
+
|
599 |
+
net = EVA02_ViT(
|
600 |
+
img_size= cfg.MODEL.EVA02.IMAGE_SIZE,
|
601 |
+
patch_size=cfg.MODEL.EVA02.PATCH_SIZE,
|
602 |
+
window_size= cfg.MODEL.EVA02.WINDOW_SIZE,
|
603 |
+
embed_dim= cfg.MODEL.EVA02.DMBED_DIM,
|
604 |
+
depth= cfg.MODEL.EVA02.DEPTH,
|
605 |
+
num_heads= cfg.MODEL.EVA02.NUM_HEADS ,
|
606 |
+
drop_path_rate= cfg.MODEL.EVA02.DROP_PATH_RATE,
|
607 |
+
mlp_ratio= cfg.MODEL.EVA02.MLP_RATIO,
|
608 |
+
# qkv_bias=True,
|
609 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
610 |
+
window_block_indexes= cfg.MODEL.EVA02.WINDOW_BLOCK_INDEXES,
|
611 |
+
# residual_block_indexes=[],
|
612 |
+
# use_rel_pos=False,
|
613 |
+
use_act_checkpoint = cfg.MODEL.EVA02.CHECKPOINT,
|
614 |
+
out_feature="last_feat",
|
615 |
+
# intp_freq=True,
|
616 |
+
),
|
617 |
+
in_feature = "last_feat",
|
618 |
+
out_channels=256,
|
619 |
+
scale_factors=(2.0, 1.0, 0.5), # (4.0, 2.0, 1.0, 0.5) in ViTDet
|
620 |
+
top_block=LastLevelMaxPool(),
|
621 |
+
norm="LN",
|
622 |
+
square_pad=cfg.MODEL.EVA02.IMAGE_SIZE,
|
623 |
+
|
624 |
+
)
|
625 |
+
|
626 |
+
pretrained_weight = cfg.MODEL.EVA02.PRETRAINED_WEIGHT
|
627 |
+
if pretrained_weight:
|
628 |
+
checkpoint = torch.load(pretrained_weight, map_location='cpu')
|
629 |
+
print(f'\nload pretrain weight from {pretrained_weight} \n')
|
630 |
+
|
631 |
+
self.load_state_dict(checkpoint['model'], strict=False)
|
632 |
+
|
633 |
+
|
634 |
+
|
635 |
+
def output_shape(self):
|
636 |
+
return {
|
637 |
+
name: ShapeSpec(
|
638 |
+
channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
|
639 |
+
)
|
640 |
+
for name in self._out_features
|
641 |
+
}
|
642 |
+
|
643 |
+
@property
|
644 |
+
def size_divisibility(self):
|
645 |
+
return 32
|
646 |
+
|
647 |
+
|
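For orientation, D2_EVA02 above is simply an EVA02_ViT wrapped in SimpleFeaturePyramid with config-driven arguments. The sketch below is illustrative only and not part of eva02.py: it assumes the Block/attention code earlier in this file accepts xattn=False (so xformers is not required) and builds a small backbone directly, printing the pyramid shapes.

# Illustrative sketch, not part of eva02.py: mirror what D2_EVA02 builds from cfg.
import torch
from detectron2.modeling.backbone.fpn import LastLevelMaxPool

net = EVA02_ViT(
    img_size=256, patch_size=16, embed_dim=192, depth=4, num_heads=3,
    window_size=16, window_block_indexes=(0, 1, 2), xattn=False,  # xattn=False avoids xformers
)
fpn = SimpleFeaturePyramid(
    net=net, in_feature="last_feat", out_channels=256,
    scale_factors=(2.0, 1.0, 0.5), top_block=LastLevelMaxPool(), norm="LN",
)
feats = fpn(torch.randn(1, 3, 256, 256))
print({k: tuple(v.shape) for k, v in feats.items()})  # p3 (stride 8) ... p6 (stride 64), 256 channels each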
GLEE/glee/backbone/eva_01_utils.py
ADDED
@@ -0,0 +1,222 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import math
import numpy as np
from scipy import interpolate
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = [
    "window_partition",
    "window_unpartition",
    "add_decomposed_rel_pos",
    "get_abs_pos",
    "PatchEmbed",
]


def window_partition(x, window_size):
    """
    Partition into non-overlapping windows with padding if needed.
    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape

    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)


def window_unpartition(windows, window_size, pad_hw, hw):
    """
    Window unpartition into original sequences and removing padding.
    Args:
        x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x


def get_rel_pos(q_size, k_size, rel_pos, interp_type):
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.
    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    # Interpolate rel pos if needed.
    if rel_pos.shape[0] != max_rel_dist:
        if interp_type == "vitdet":
            # the vitdet impl:
            # https://github.com/facebookresearch/detectron2/blob/96c752ce821a3340e27edd51c28a00665dd32a30/detectron2/modeling/backbone/utils.py#L77.

            rel_pos_resized = F.interpolate(
                rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
                size=max_rel_dist,
                mode="linear",
            )
            rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
        elif interp_type == "beit":
            # steal from beit https://github.com/microsoft/unilm/tree/master/beit
            # modified by Yuxin Fang

            src_size = rel_pos.shape[0]
            dst_size = max_rel_dist

            q = 1.0903078
            dis = []

            cur = 1
            for i in range(src_size // 2):
                dis.append(cur)
                cur += q ** (i + 1)

            r_ids = [-_ for _ in reversed(dis)]
            x = r_ids + [0] + dis
            t = dst_size // 2.0
            dx = np.arange(-t, t + 0.1, 1.0)

            all_rel_pos_bias = []
            for i in range(rel_pos.shape[1]):
                # a hack from https://github.com/baaivision/EVA/issues/8,
                # could also be used in fine-tuning but the performance hasn't been tested.
                z = rel_pos[:, i].view(src_size).cpu().float().detach().numpy()
                f = interpolate.interp1d(x, z, kind='cubic', fill_value="extrapolate")
                all_rel_pos_bias.append(
                    torch.Tensor(f(dx)).contiguous().view(-1, 1).to(rel_pos.device))
            rel_pos_resized = torch.cat(all_rel_pos_bias, dim=-1)
        else:
            raise NotImplementedError()
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.long()]


def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size, interp_type):
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h, interp_type)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w, interp_type)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

    attn = (
        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    ).view(B, q_h * q_w, k_h * k_w)

    return attn


def get_abs_pos(abs_pos, has_cls_token, hw):
    """
    Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
    dimension for the original embeddings.
    Args:
        abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
        has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
        hw (Tuple): size of input image tokens.

    Returns:
        Absolute positional embeddings after processing with shape (1, H, W, C)
    """
    h, w = hw
    if has_cls_token:
        abs_pos = abs_pos[:, 1:]
    xy_num = abs_pos.shape[1]
    size = int(math.sqrt(xy_num))
    assert size * size == xy_num

    if size != h or size != w:
        new_abs_pos = F.interpolate(
            abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
            size=(h, w),
            mode="bicubic",
            align_corners=False,
        )

        return new_abs_pos.permute(0, 2, 3, 1)
    else:
        return abs_pos.reshape(1, h, w, -1)


class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768
    ):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x):
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        return x

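A quick way to see what window_partition and window_unpartition do is the round-trip below; it is an illustrative sketch, not part of eva_01_utils.py.

# Illustrative round-trip for window_partition / window_unpartition.
import torch

tokens = torch.randn(2, 14, 14, 96)                    # [B, H, W, C]; 14 is not divisible by 8
windows, pad_hw = window_partition(tokens, window_size=8)
print(windows.shape, pad_hw)                           # torch.Size([8, 8, 8, 96]) (16, 16)
restored = window_unpartition(windows, 8, pad_hw, (14, 14))
assert torch.equal(restored, tokens)                   # padding is stripped again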
GLEE/glee/backbone/eva_02_utils.py
ADDED
@@ -0,0 +1,356 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import math
import numpy as np
from scipy import interpolate
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = [
    "window_partition",
    "window_unpartition",
    "add_decomposed_rel_pos",
    "get_abs_pos",
    "PatchEmbed",
    "VisionRotaryEmbeddingFast",
]


def window_partition(x, window_size):
    """
    Partition into non-overlapping windows with padding if needed.
    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape

    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)


def window_unpartition(windows, window_size, pad_hw, hw):
    """
    Window unpartition into original sequences and removing padding.
    Args:
        x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x


def get_rel_pos(q_size, k_size, rel_pos):
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.
    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    use_log_interpolation = True

    # Interpolate rel pos if needed.
    if rel_pos.shape[0] != max_rel_dist:
        if not use_log_interpolation:
            # Interpolate rel pos.
            rel_pos_resized = F.interpolate(
                rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
                size=max_rel_dist,
                mode="linear",
            )
            rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
        else:
            src_size = rel_pos.shape[0]
            dst_size = max_rel_dist

            # q = 1.13492
            q = 1.0903078
            dis = []

            cur = 1
            for i in range(src_size // 2):
                dis.append(cur)
                cur += q ** (i + 1)

            r_ids = [-_ for _ in reversed(dis)]
            x = r_ids + [0] + dis
            t = dst_size // 2.0
            dx = np.arange(-t, t + 0.1, 1.0)
            # print("x = %s" % str(x))
            # print("dx = %s" % str(dx))
            all_rel_pos_bias = []
            for i in range(rel_pos.shape[1]):
                z = rel_pos[:, i].view(src_size).cpu().float().numpy()
                f = interpolate.interp1d(x, z, kind='cubic', fill_value="extrapolate")
                all_rel_pos_bias.append(
                    torch.Tensor(f(dx)).contiguous().view(-1, 1).to(rel_pos.device))
            rel_pos_resized = torch.cat(all_rel_pos_bias, dim=-1)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.long()]


def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

    attn = (
        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    ).view(B, q_h * q_w, k_h * k_w)

    return attn


def get_abs_pos(abs_pos, has_cls_token, hw):
    """
    Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
    dimension for the original embeddings.
    Args:
        abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
        has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
        hw (Tuple): size of input image tokens.

    Returns:
        Absolute positional embeddings after processing with shape (1, H, W, C)
    """
    h, w = hw
    if has_cls_token:
        abs_pos = abs_pos[:, 1:]
    xy_num = abs_pos.shape[1]
    size = int(math.sqrt(xy_num))
    assert size * size == xy_num

    if size != h or size != w:
        new_abs_pos = F.interpolate(
            abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
            size=(h, w),
            mode="bicubic",
            align_corners=False,
        )

        return new_abs_pos.permute(0, 2, 3, 1)
    else:
        return abs_pos.reshape(1, h, w, -1)


class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768
    ):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x):
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        return x


from math import pi

import torch
from torch import nn

from einops import rearrange, repeat


def broadcat(tensors, dim=-1):
    num_tensors = len(tensors)
    shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
    assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
    shape_len = list(shape_lens)[0]
    dim = (dim + shape_len) if dim < 0 else dim
    dims = list(zip(*map(lambda t: list(t.shape), tensors)))
    expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
    assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatenation'
    max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
    expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
    expanded_dims.insert(dim, (dim, dims[dim]))
    expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
    tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
    return torch.cat(tensors, dim=dim)


def rotate_half(x):
    x = rearrange(x, '... (d r) -> ... d r', r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, '... d r -> ... (d r)')


class VisionRotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim,
        pt_seq_len,
        ft_seq_len=None,
        custom_freqs=None,
        freqs_for='lang',
        theta=10000,
        max_freq=10,
        num_freqs=1,
    ):
        super().__init__()
        if custom_freqs:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        freqs_h = torch.einsum('..., f -> ... f', t, freqs)
        freqs_h = repeat(freqs_h, '... n -> ... (n r)', r=2)

        freqs_w = torch.einsum('..., f -> ... f', t, freqs)
        freqs_w = repeat(freqs_w, '... n -> ... (n r)', r=2)

        freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)

        self.register_buffer("freqs_cos", freqs.cos())
        self.register_buffer("freqs_sin", freqs.sin())

        print('======== shape of rope freq', self.freqs_cos.shape, '========')

    def forward(self, t, start_index=0):
        rot_dim = self.freqs_cos.shape[-1]
        end_index = start_index + rot_dim
        assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
        t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
        t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
        return torch.cat((t_left, t, t_right), dim=-1)


class VisionRotaryEmbeddingFast(nn.Module):
    def __init__(
        self,
        dim,
        pt_seq_len=16,
        ft_seq_len=None,
        custom_freqs=None,
        freqs_for='lang',
        theta=10000,
        max_freq=10,
        num_freqs=1,
    ):
        super().__init__()
        if custom_freqs:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        freqs = torch.einsum('..., f -> ... f', t, freqs)
        freqs = repeat(freqs, '... n -> ... (n r)', r=2)
        freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)

        freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
        freqs_sin = freqs.sin().view(-1, freqs.shape[-1])

        self.register_buffer("freqs_cos", freqs_cos)
        self.register_buffer("freqs_sin", freqs_sin)

        print('======== shape of rope freq', self.freqs_cos.shape, '========')

    # def forward(self, t): return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
    def forward(self, t):
        if t.shape[2] != self.freqs_cos.shape[0]:
            t_len = t.shape[2]
            output = t * self.freqs_cos[:t_len] + rotate_half(t) * self.freqs_sin[:t_len]
        else:
            output = t * self.freqs_cos + rotate_half(t) * self.freqs_sin
        return output

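VisionRotaryEmbeddingFast is applied to query/key tensors shaped [B, num_heads, H*W, head_dim] inside the EVA-02 attention blocks. The snippet below is an illustrative shape check only, not part of eva_02_utils.py; the head_dim here must equal twice the dim passed to the embedding.

# Illustrative use of VisionRotaryEmbeddingFast on [B, heads, H*W, head_dim] tensors.
import torch

rope = VisionRotaryEmbeddingFast(dim=32, pt_seq_len=16, ft_seq_len=16)  # dim = head_dim // 2
q = torch.randn(1, 6, 16 * 16, 64)         # 16x16 tokens, head_dim = 64
q_rot = rope(q)                             # rotates channel pairs according to 2D position
print(q_rot.shape)                          # torch.Size([1, 6, 256, 64]) -- shape is unchanged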
GLEE/glee/backbone/internimage.py
ADDED
@@ -0,0 +1,737 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# InternImage
|
3 |
+
# Copyright (c) 2022 OpenGVLab
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# --------------------------------------------------------
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torch.utils.checkpoint as checkpoint
|
11 |
+
from timm.models.layers import trunc_normal_, DropPath
|
12 |
+
|
13 |
+
from detectron2.utils.logger import setup_logger
|
14 |
+
from detectron2.modeling.backbone import Backbone
|
15 |
+
|
16 |
+
|
17 |
+
from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
|
18 |
+
from .ops_dcnv3 import modules as opsm
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
class to_channels_first(nn.Module):
|
23 |
+
|
24 |
+
def __init__(self):
|
25 |
+
super().__init__()
|
26 |
+
|
27 |
+
def forward(self, x):
|
28 |
+
return x.permute(0, 3, 1, 2)
|
29 |
+
|
30 |
+
|
31 |
+
class to_channels_last(nn.Module):
|
32 |
+
|
33 |
+
def __init__(self):
|
34 |
+
super().__init__()
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
return x.permute(0, 2, 3, 1)
|
38 |
+
|
39 |
+
|
40 |
+
def build_norm_layer(dim,
|
41 |
+
norm_layer,
|
42 |
+
in_format='channels_last',
|
43 |
+
out_format='channels_last',
|
44 |
+
eps=1e-6):
|
45 |
+
layers = []
|
46 |
+
if norm_layer == 'BN':
|
47 |
+
if in_format == 'channels_last':
|
48 |
+
layers.append(to_channels_first())
|
49 |
+
layers.append(nn.BatchNorm2d(dim))
|
50 |
+
if out_format == 'channels_last':
|
51 |
+
layers.append(to_channels_last())
|
52 |
+
elif norm_layer == 'LN':
|
53 |
+
if in_format == 'channels_first':
|
54 |
+
layers.append(to_channels_last())
|
55 |
+
layers.append(nn.LayerNorm(dim, eps=eps))
|
56 |
+
if out_format == 'channels_first':
|
57 |
+
layers.append(to_channels_first())
|
58 |
+
else:
|
59 |
+
raise NotImplementedError(
|
60 |
+
f'build_norm_layer does not support {norm_layer}')
|
61 |
+
return nn.Sequential(*layers)
|
62 |
+
|
63 |
+
|
64 |
+
def build_act_layer(act_layer):
|
65 |
+
if act_layer == 'ReLU':
|
66 |
+
return nn.ReLU(inplace=True)
|
67 |
+
elif act_layer == 'SiLU':
|
68 |
+
return nn.SiLU(inplace=True)
|
69 |
+
elif act_layer == 'GELU':
|
70 |
+
return nn.GELU()
|
71 |
+
|
72 |
+
raise NotImplementedError(f'build_act_layer does not support {act_layer}')
|
73 |
+
|
74 |
+
|
75 |
+
class CrossAttention(nn.Module):
|
76 |
+
r""" Cross Attention Module
|
77 |
+
Args:
|
78 |
+
dim (int): Number of input channels.
|
79 |
+
num_heads (int): Number of attention heads. Default: 8
|
80 |
+
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
|
81 |
+
Default: False.
|
82 |
+
qk_scale (float | None, optional): Override default qk scale of
|
83 |
+
head_dim ** -0.5 if set. Default: None.
|
84 |
+
attn_drop (float, optional): Dropout ratio of attention weight.
|
85 |
+
Default: 0.0
|
86 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
87 |
+
attn_head_dim (int, optional): Dimension of attention head.
|
88 |
+
out_dim (int, optional): Dimension of output.
|
89 |
+
"""
|
90 |
+
|
91 |
+
def __init__(self,
|
92 |
+
dim,
|
93 |
+
num_heads=8,
|
94 |
+
qkv_bias=False,
|
95 |
+
qk_scale=None,
|
96 |
+
attn_drop=0.,
|
97 |
+
proj_drop=0.,
|
98 |
+
attn_head_dim=None,
|
99 |
+
out_dim=None):
|
100 |
+
super().__init__()
|
101 |
+
if out_dim is None:
|
102 |
+
out_dim = dim
|
103 |
+
self.num_heads = num_heads
|
104 |
+
head_dim = dim // num_heads
|
105 |
+
if attn_head_dim is not None:
|
106 |
+
head_dim = attn_head_dim
|
107 |
+
all_head_dim = head_dim * self.num_heads
|
108 |
+
self.scale = qk_scale or head_dim ** -0.5
|
109 |
+
assert all_head_dim == dim
|
110 |
+
|
111 |
+
self.q = nn.Linear(dim, all_head_dim, bias=False)
|
112 |
+
self.k = nn.Linear(dim, all_head_dim, bias=False)
|
113 |
+
self.v = nn.Linear(dim, all_head_dim, bias=False)
|
114 |
+
|
115 |
+
if qkv_bias:
|
116 |
+
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
117 |
+
self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
|
118 |
+
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
119 |
+
else:
|
120 |
+
self.q_bias = None
|
121 |
+
self.k_bias = None
|
122 |
+
self.v_bias = None
|
123 |
+
|
124 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
125 |
+
self.proj = nn.Linear(all_head_dim, out_dim)
|
126 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
127 |
+
|
128 |
+
def forward(self, x, k=None, v=None):
|
129 |
+
B, N, C = x.shape
|
130 |
+
N_k = k.shape[1]
|
131 |
+
N_v = v.shape[1]
|
132 |
+
|
133 |
+
q_bias, k_bias, v_bias = None, None, None
|
134 |
+
if self.q_bias is not None:
|
135 |
+
q_bias = self.q_bias
|
136 |
+
k_bias = self.k_bias
|
137 |
+
v_bias = self.v_bias
|
138 |
+
|
139 |
+
q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
|
140 |
+
q = q.reshape(B, N, 1, self.num_heads,
|
141 |
+
-1).permute(2, 0, 3, 1,
|
142 |
+
4).squeeze(0) # (B, N_head, N_q, dim)
|
143 |
+
|
144 |
+
k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
|
145 |
+
k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1,
|
146 |
+
4).squeeze(0)
|
147 |
+
|
148 |
+
v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
|
149 |
+
v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1,
|
150 |
+
4).squeeze(0)
|
151 |
+
|
152 |
+
q = q * self.scale
|
153 |
+
attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
|
154 |
+
|
155 |
+
attn = attn.softmax(dim=-1)
|
156 |
+
attn = self.attn_drop(attn)
|
157 |
+
|
158 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
159 |
+
x = self.proj(x)
|
160 |
+
x = self.proj_drop(x)
|
161 |
+
|
162 |
+
return x
|
163 |
+
|
164 |
+
|
165 |
+
class AttentiveBlock(nn.Module):
|
166 |
+
r"""Attentive Block
|
167 |
+
Args:
|
168 |
+
dim (int): Number of input channels.
|
169 |
+
num_heads (int): Number of attention heads. Default: 8
|
170 |
+
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
|
171 |
+
Default: False.
|
172 |
+
qk_scale (float | None, optional): Override default qk scale of
|
173 |
+
head_dim ** -0.5 if set. Default: None.
|
174 |
+
drop (float, optional): Dropout rate. Default: 0.0.
|
175 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0.
|
176 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate.
|
177 |
+
Default: 0.0.
|
178 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm.
|
179 |
+
attn_head_dim (int, optional): Dimension of attention head. Default: None.
|
180 |
+
out_dim (int, optional): Dimension of output. Default: None.
|
181 |
+
"""
|
182 |
+
|
183 |
+
def __init__(self,
|
184 |
+
dim,
|
185 |
+
num_heads,
|
186 |
+
qkv_bias=False,
|
187 |
+
qk_scale=None,
|
188 |
+
drop=0.,
|
189 |
+
attn_drop=0.,
|
190 |
+
drop_path=0.,
|
191 |
+
norm_layer="LN",
|
192 |
+
attn_head_dim=None,
|
193 |
+
out_dim=None):
|
194 |
+
super().__init__()
|
195 |
+
|
196 |
+
self.norm1_q = build_norm_layer(dim, norm_layer, eps=1e-6)
|
197 |
+
self.norm1_k = build_norm_layer(dim, norm_layer, eps=1e-6)
|
198 |
+
self.norm1_v = build_norm_layer(dim, norm_layer, eps=1e-6)
|
199 |
+
self.cross_dcn = CrossAttention(dim,
|
200 |
+
num_heads=num_heads,
|
201 |
+
qkv_bias=qkv_bias,
|
202 |
+
qk_scale=qk_scale,
|
203 |
+
attn_drop=attn_drop,
|
204 |
+
proj_drop=drop,
|
205 |
+
attn_head_dim=attn_head_dim,
|
206 |
+
out_dim=out_dim)
|
207 |
+
|
208 |
+
self.drop_path = DropPath(
|
209 |
+
drop_path) if drop_path > 0. else nn.Identity()
|
210 |
+
|
211 |
+
def forward(self,
|
212 |
+
x_q,
|
213 |
+
x_kv,
|
214 |
+
pos_q,
|
215 |
+
pos_k,
|
216 |
+
bool_masked_pos,
|
217 |
+
rel_pos_bias=None):
|
218 |
+
x_q = self.norm1_q(x_q + pos_q)
|
219 |
+
x_k = self.norm1_k(x_kv + pos_k)
|
220 |
+
x_v = self.norm1_v(x_kv)
|
221 |
+
|
222 |
+
x = self.cross_dcn(x_q, k=x_k, v=x_v)
|
223 |
+
|
224 |
+
return x
|
225 |
+
|
226 |
+
|
227 |
+
class AttentionPoolingBlock(AttentiveBlock):
|
228 |
+
|
229 |
+
def forward(self, x):
|
230 |
+
x_q = x.mean(1, keepdim=True)
|
231 |
+
x_kv = x
|
232 |
+
pos_q, pos_k = 0, 0
|
233 |
+
x = super().forward(x_q, x_kv, pos_q, pos_k,
|
234 |
+
bool_masked_pos=None,
|
235 |
+
rel_pos_bias=None)
|
236 |
+
x = x.squeeze(1)
|
237 |
+
return x
|
238 |
+
|
239 |
+
|
240 |
+
class StemLayer(nn.Module):
|
241 |
+
r""" Stem layer of InternImage
|
242 |
+
Args:
|
243 |
+
in_chans (int): number of input channels
|
244 |
+
out_chans (int): number of output channels
|
245 |
+
act_layer (str): activation layer
|
246 |
+
norm_layer (str): normalization layer
|
247 |
+
"""
|
248 |
+
|
249 |
+
def __init__(self,
|
250 |
+
in_chans=3,
|
251 |
+
out_chans=96,
|
252 |
+
act_layer='GELU',
|
253 |
+
norm_layer='BN'):
|
254 |
+
super().__init__()
|
255 |
+
self.conv1 = nn.Conv2d(in_chans,
|
256 |
+
out_chans // 2,
|
257 |
+
kernel_size=3,
|
258 |
+
stride=2,
|
259 |
+
padding=1)
|
260 |
+
self.norm1 = build_norm_layer(out_chans // 2, norm_layer,
|
261 |
+
'channels_first', 'channels_first')
|
262 |
+
self.act = build_act_layer(act_layer)
|
263 |
+
self.conv2 = nn.Conv2d(out_chans // 2,
|
264 |
+
out_chans,
|
265 |
+
kernel_size=3,
|
266 |
+
stride=2,
|
267 |
+
padding=1)
|
268 |
+
self.norm2 = build_norm_layer(out_chans, norm_layer, 'channels_first',
|
269 |
+
'channels_last')
|
270 |
+
|
271 |
+
def forward(self, x):
|
272 |
+
x = self.conv1(x)
|
273 |
+
x = self.norm1(x)
|
274 |
+
x = self.act(x)
|
275 |
+
x = self.conv2(x)
|
276 |
+
x = self.norm2(x)
|
277 |
+
return x
|
278 |
+
|
279 |
+
|
280 |
+
class DownsampleLayer(nn.Module):
|
281 |
+
r""" Downsample layer of InternImage
|
282 |
+
Args:
|
283 |
+
channels (int): number of input channels
|
284 |
+
norm_layer (str): normalization layer
|
285 |
+
"""
|
286 |
+
|
287 |
+
def __init__(self, channels, norm_layer='LN'):
|
288 |
+
super().__init__()
|
289 |
+
self.conv = nn.Conv2d(channels,
|
290 |
+
2 * channels,
|
291 |
+
kernel_size=3,
|
292 |
+
stride=2,
|
293 |
+
padding=1,
|
294 |
+
bias=False)
|
295 |
+
self.norm = build_norm_layer(2 * channels, norm_layer,
|
296 |
+
'channels_first', 'channels_last')
|
297 |
+
|
298 |
+
def forward(self, x):
|
299 |
+
x = self.conv(x.permute(0, 3, 1, 2))
|
300 |
+
x = self.norm(x)
|
301 |
+
return x
|
302 |
+
|
303 |
+
|
304 |
+
class MLPLayer(nn.Module):
|
305 |
+
r""" MLP layer of InternImage
|
306 |
+
Args:
|
307 |
+
in_features (int): number of input features
|
308 |
+
hidden_features (int): number of hidden features
|
309 |
+
out_features (int): number of output features
|
310 |
+
act_layer (str): activation layer
|
311 |
+
drop (float): dropout rate
|
312 |
+
"""
|
313 |
+
|
314 |
+
def __init__(self,
|
315 |
+
in_features,
|
316 |
+
hidden_features=None,
|
317 |
+
out_features=None,
|
318 |
+
act_layer='GELU',
|
319 |
+
drop=0.):
|
320 |
+
super().__init__()
|
321 |
+
out_features = out_features or in_features
|
322 |
+
hidden_features = hidden_features or in_features
|
323 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
324 |
+
self.act = build_act_layer(act_layer)
|
325 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
326 |
+
self.drop = nn.Dropout(drop)
|
327 |
+
|
328 |
+
def forward(self, x):
|
329 |
+
x = self.fc1(x)
|
330 |
+
x = self.act(x)
|
331 |
+
x = self.drop(x)
|
332 |
+
x = self.fc2(x)
|
333 |
+
x = self.drop(x)
|
334 |
+
return x
|
335 |
+
|
336 |
+
|
337 |
+
class InternImageLayer(nn.Module):
|
338 |
+
r""" Basic layer of InternImage
|
339 |
+
Args:
|
340 |
+
core_op (nn.Module): core operation of InternImage
|
341 |
+
channels (int): number of input channels
|
342 |
+
groups (list): Groups of each block.
|
343 |
+
mlp_ratio (float): ratio of mlp hidden features to input channels
|
344 |
+
drop (float): dropout rate
|
345 |
+
drop_path (float): drop path rate
|
346 |
+
act_layer (str): activation layer
|
347 |
+
norm_layer (str): normalization layer
|
348 |
+
post_norm (bool): whether to use post normalization
|
349 |
+
layer_scale (float): layer scale
|
350 |
+
offset_scale (float): offset scale
|
351 |
+
with_cp (bool): whether to use checkpoint
|
352 |
+
"""
|
353 |
+
|
354 |
+
def __init__(self,
|
355 |
+
core_op,
|
356 |
+
channels,
|
357 |
+
groups,
|
358 |
+
mlp_ratio=4.,
|
359 |
+
drop=0.,
|
360 |
+
drop_path=0.,
|
361 |
+
act_layer='GELU',
|
362 |
+
norm_layer='LN',
|
363 |
+
post_norm=False,
|
364 |
+
layer_scale=None,
|
365 |
+
offset_scale=1.0,
|
366 |
+
with_cp=False,
|
367 |
+
dw_kernel_size=None, # for InternImage-H/G
|
368 |
+
res_post_norm=False, # for InternImage-H/G
|
369 |
+
center_feature_scale=False): # for InternImage-H/G
|
370 |
+
super().__init__()
|
371 |
+
self.channels = channels
|
372 |
+
self.groups = groups
|
373 |
+
self.mlp_ratio = mlp_ratio
|
374 |
+
self.with_cp = with_cp
|
375 |
+
|
376 |
+
self.norm1 = build_norm_layer(channels, 'LN')
|
377 |
+
self.post_norm = post_norm
|
378 |
+
self.dcn = core_op(
|
379 |
+
channels=channels,
|
380 |
+
kernel_size=3,
|
381 |
+
stride=1,
|
382 |
+
pad=1,
|
383 |
+
dilation=1,
|
384 |
+
group=groups,
|
385 |
+
offset_scale=offset_scale,
|
386 |
+
act_layer=act_layer,
|
387 |
+
norm_layer=norm_layer,
|
388 |
+
dw_kernel_size=dw_kernel_size, # for InternImage-H/G
|
389 |
+
center_feature_scale=center_feature_scale) # for InternImage-H/G
|
390 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. \
|
391 |
+
else nn.Identity()
|
392 |
+
self.norm2 = build_norm_layer(channels, 'LN')
|
393 |
+
self.mlp = MLPLayer(in_features=channels,
|
394 |
+
hidden_features=int(channels * mlp_ratio),
|
395 |
+
act_layer=act_layer,
|
396 |
+
drop=drop)
|
397 |
+
self.layer_scale = layer_scale is not None
|
398 |
+
if self.layer_scale:
|
399 |
+
self.gamma1 = nn.Parameter(layer_scale * torch.ones(channels),
|
400 |
+
requires_grad=True)
|
401 |
+
self.gamma2 = nn.Parameter(layer_scale * torch.ones(channels),
|
402 |
+
requires_grad=True)
|
403 |
+
self.res_post_norm = res_post_norm
|
404 |
+
if res_post_norm:
|
405 |
+
self.res_post_norm1 = build_norm_layer(channels, 'LN')
|
406 |
+
self.res_post_norm2 = build_norm_layer(channels, 'LN')
|
407 |
+
|
408 |
+
def forward(self, x):
|
409 |
+
|
410 |
+
def _inner_forward(x):
|
411 |
+
if not self.layer_scale:
|
412 |
+
if self.post_norm:
|
413 |
+
x = x + self.drop_path(self.norm1(self.dcn(x)))
|
414 |
+
x = x + self.drop_path(self.norm2(self.mlp(x)))
|
415 |
+
elif self.res_post_norm: # for InternImage-H/G
|
416 |
+
x = x + self.drop_path(self.res_post_norm1(self.dcn(self.norm1(x))))
|
417 |
+
x = x + self.drop_path(self.res_post_norm2(self.mlp(self.norm2(x))))
|
418 |
+
else:
|
419 |
+
x = x + self.drop_path(self.dcn(self.norm1(x)))
|
420 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
421 |
+
return x
|
422 |
+
if self.post_norm:
|
423 |
+
x = x + self.drop_path(self.gamma1 * self.norm1(self.dcn(x)))
|
424 |
+
x = x + self.drop_path(self.gamma2 * self.norm2(self.mlp(x)))
|
425 |
+
else:
|
426 |
+
x = x + self.drop_path(self.gamma1 * self.dcn(self.norm1(x)))
|
427 |
+
x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
|
428 |
+
return x
|
429 |
+
|
430 |
+
if self.with_cp and x.requires_grad:
|
431 |
+
x = checkpoint.checkpoint(_inner_forward, x)
|
432 |
+
else:
|
433 |
+
x = _inner_forward(x)
|
434 |
+
return x
|
435 |
+
|
436 |
+
|
437 |
+
class InternImageBlock(nn.Module):
|
438 |
+
r""" Block of InternImage
|
439 |
+
Args:
|
440 |
+
core_op (nn.Module): core operation of InternImage
|
441 |
+
channels (int): number of input channels
|
442 |
+
depths (list): Depth of each block.
|
443 |
+
groups (list): Groups of each block.
|
444 |
+
mlp_ratio (float): ratio of mlp hidden features to input channels
|
445 |
+
drop (float): dropout rate
|
446 |
+
drop_path (float): drop path rate
|
447 |
+
act_layer (str): activation layer
|
448 |
+
norm_layer (str): normalization layer
|
449 |
+
post_norm (bool): whether to use post normalization
|
450 |
+
layer_scale (float): layer scale
|
451 |
+
offset_scale (float): offset scale
|
452 |
+
with_cp (bool): whether to use checkpoint
|
453 |
+
"""
|
454 |
+
|
455 |
+
def __init__(self,
|
456 |
+
core_op,
|
457 |
+
channels,
|
458 |
+
depth,
|
459 |
+
groups,
|
460 |
+
downsample=True,
|
461 |
+
mlp_ratio=4.,
|
462 |
+
drop=0.,
|
463 |
+
drop_path=0.,
|
464 |
+
act_layer='GELU',
|
465 |
+
norm_layer='LN',
|
466 |
+
post_norm=False,
|
467 |
+
offset_scale=1.0,
|
468 |
+
layer_scale=None,
|
469 |
+
with_cp=False,
|
470 |
+
dw_kernel_size=None, # for InternImage-H/G
|
471 |
+
post_norm_block_ids=None, # for InternImage-H/G
|
472 |
+
res_post_norm=False, # for InternImage-H/G
|
473 |
+
center_feature_scale=False): # for InternImage-H/G
|
474 |
+
super().__init__()
|
475 |
+
self.channels = channels
|
476 |
+
self.depth = depth
|
477 |
+
self.post_norm = post_norm
|
478 |
+
self.center_feature_scale = center_feature_scale
|
479 |
+
|
480 |
+
self.blocks = nn.ModuleList([
|
481 |
+
InternImageLayer(
|
482 |
+
core_op=core_op,
|
483 |
+
channels=channels,
|
484 |
+
groups=groups,
|
485 |
+
mlp_ratio=mlp_ratio,
|
486 |
+
drop=drop,
|
487 |
+
drop_path=drop_path[i] if isinstance(
|
488 |
+
drop_path, list) else drop_path,
|
489 |
+
act_layer=act_layer,
|
490 |
+
norm_layer=norm_layer,
|
491 |
+
post_norm=post_norm,
|
492 |
+
layer_scale=layer_scale,
|
493 |
+
offset_scale=offset_scale,
|
494 |
+
with_cp=with_cp,
|
495 |
+
dw_kernel_size=dw_kernel_size, # for InternImage-H/G
|
496 |
+
res_post_norm=res_post_norm, # for InternImage-H/G
|
497 |
+
center_feature_scale=center_feature_scale # for InternImage-H/G
|
498 |
+
) for i in range(depth)
|
499 |
+
])
|
500 |
+
if not self.post_norm or center_feature_scale:
|
501 |
+
self.norm = build_norm_layer(channels, 'LN')
|
502 |
+
self.post_norm_block_ids = post_norm_block_ids
|
503 |
+
if post_norm_block_ids is not None: # for InternImage-H/G
|
504 |
+
self.post_norms = nn.ModuleList(
|
505 |
+
[build_norm_layer(channels, 'LN', eps=1e-6) for _ in post_norm_block_ids]
|
506 |
+
)
|
507 |
+
self.downsample = DownsampleLayer(
|
508 |
+
channels=channels, norm_layer=norm_layer) if downsample else None
|
509 |
+
|
510 |
+
def forward(self, x, return_wo_downsample=False):
|
511 |
+
for i, blk in enumerate(self.blocks):
|
512 |
+
x = blk(x)
|
513 |
+
if (self.post_norm_block_ids is not None) and (i in self.post_norm_block_ids):
|
514 |
+
index = self.post_norm_block_ids.index(i)
|
515 |
+
x = self.post_norms[index](x) # for InternImage-H/G
|
516 |
+
if not self.post_norm or self.center_feature_scale:
|
517 |
+
x = self.norm(x)
|
518 |
+
if return_wo_downsample:
|
519 |
+
x_ = x
|
520 |
+
if self.downsample is not None:
|
521 |
+
x = self.downsample(x)
|
522 |
+
|
523 |
+
if return_wo_downsample:
|
524 |
+
return x, x_
|
525 |
+
return x
|
526 |
+
|
class InternImage(Backbone):
    r"""InternImage
        A PyTorch impl of : `InternImage: Exploring Large-Scale Vision Foundation Models with Deformable Convolutions` -
          https://arxiv.org/abs/2211.05778
    Args:
        core_op (str): Core operator. Default: 'DCNv3'
        channels (int): Number of channels in the first stage. Default: 64
        depths (list): Depth of each block. Default: [3, 4, 18, 5]
        groups (list): Groups of each block. Default: [3, 6, 12, 24]
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        drop_rate (float): Probability of an element to be zeroed. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        act_layer (str): Activation layer. Default: 'GELU'
        norm_layer (str): Normalization layer. Default: 'LN'
        layer_scale (bool): Whether to use layer scale. Default: False
        cls_scale (bool): Whether to use class scale. Default: False
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False
        dw_kernel_size (int): Size of the dwconv. Default: None
        level2_post_norm (bool): Whether to use level2 post norm. Default: False
        level2_post_norm_block_ids (list): Indexes of post norm blocks. Default: None
        res_post_norm (bool): Whether to use res post norm. Default: False
        center_feature_scale (bool): Whether to use center feature scale. Default: False
    """

    def __init__(self,
                 core_op='DCNv3',
                 channels=64,
                 depths=[3, 4, 18, 5],
                 groups=[3, 6, 12, 24],
                 mlp_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.2,
                 drop_path_type='linear',
                 act_layer='GELU',
                 norm_layer='LN',
                 layer_scale=None,
                 offset_scale=1.0,
                 post_norm=False,
                 with_cp=False,
                 dw_kernel_size=None,  # for InternImage-H/G
                 level2_post_norm=False,  # for InternImage-H/G
                 level2_post_norm_block_ids=None,  # for InternImage-H/G
                 res_post_norm=False,  # for InternImage-H/G
                 center_feature_scale=False,  # for InternImage-H/G
                 out_indices=(0, 1, 2, 3),
                 init_cfg=None,
                 **kwargs):
        super().__init__()
        self.core_op = core_op
        self.num_levels = len(depths)
        self.depths = depths
        self.channels = channels
        self.num_features = int(channels * 2**(self.num_levels - 1))
        self.post_norm = post_norm
        self.mlp_ratio = mlp_ratio
        self.init_cfg = init_cfg
        self.out_indices = out_indices
        self.level2_post_norm_block_ids = level2_post_norm_block_ids
        logger = setup_logger(name="InternImage")
        logger.info(f'using core type: {core_op}')
        logger.info(f'using activation layer: {act_layer}')
        logger.info(f'using main norm layer: {norm_layer}')
        logger.info(f'using dpr: {drop_path_type}, {drop_path_rate}')
        logger.info(f"level2_post_norm: {level2_post_norm}")
        logger.info(f"level2_post_norm_block_ids: {level2_post_norm_block_ids}")
        logger.info(f"res_post_norm: {res_post_norm}")

        in_chans = 3
        self.patch_embed = StemLayer(in_chans=in_chans,
                                     out_chans=channels,
                                     act_layer=act_layer,
                                     norm_layer=norm_layer)
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]
        if drop_path_type == 'uniform':
            for i in range(len(dpr)):
                dpr[i] = drop_path_rate

        self.levels = nn.ModuleList()
        for i in range(self.num_levels):
            post_norm_block_ids = level2_post_norm_block_ids if level2_post_norm and (
                i == 2) else None  # for InternImage-H/G
            level = InternImageBlock(
                core_op=getattr(opsm, core_op),
                channels=int(channels * 2**i),
                depth=depths[i],
                groups=groups[i],
                mlp_ratio=self.mlp_ratio,
                drop=drop_rate,
                drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
                act_layer=act_layer,
                norm_layer=norm_layer,
                post_norm=post_norm,
                downsample=(i < self.num_levels - 1),
                layer_scale=layer_scale,
                offset_scale=offset_scale,
                with_cp=with_cp,
                dw_kernel_size=dw_kernel_size,  # for InternImage-H/G
                post_norm_block_ids=post_norm_block_ids,  # for InternImage-H/G
                res_post_norm=res_post_norm,  # for InternImage-H/G
                center_feature_scale=center_feature_scale  # for InternImage-H/G
            )
            self.levels.append(level)

        self.num_layers = len(depths)
        self.apply(self._init_weights)
        self.apply(self._init_deform_weights)

        # add basic info for d2 backbone
        self._out_features = ["res{}".format(i + 2) for i in self.out_indices]
        self._out_feature_channels = {
            "res{}".format(i + 2): self.channels * 2**i for i in self.out_indices
        }
        self._out_feature_strides = {"res{}".format(i + 2): 2 ** (i + 2) for i in self.out_indices}
        self._size_divisibility = 32

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _init_deform_weights(self, m):
        if isinstance(m, getattr(opsm, self.core_op)):
            m._reset_parameters()

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.pos_drop(x)

        # d2 needs dict output
        seq_out = {}
        for level_idx, level in enumerate(self.levels):
            x, x_ = level(x, return_wo_downsample=True)
            if level_idx in self.out_indices:
                seq_out["res{}".format(level_idx + 2)] = x_.permute(0, 3, 1, 2).contiguous()
        return seq_out


@BACKBONE_REGISTRY.register()
class D2InternImage(InternImage):
    def __init__(self, cfg, input_shape):

        super().__init__(
            core_op=cfg.MODEL.INTERNIMAGE.CORE_OP,
            channels=cfg.MODEL.INTERNIMAGE.CHANNELS,
            depths=cfg.MODEL.INTERNIMAGE.DEPTHS,
            groups=cfg.MODEL.INTERNIMAGE.GROUPS,
            mlp_ratio=cfg.MODEL.INTERNIMAGE.MLP_RATIO,
            drop_path_rate=cfg.MODEL.INTERNIMAGE.DROP_PATH_RATE,
            norm_layer=cfg.MODEL.INTERNIMAGE.NORM_LAYER,
            layer_scale=cfg.MODEL.INTERNIMAGE.LAYER_SCALE,
            offset_scale=cfg.MODEL.INTERNIMAGE.OFFSET_SCALE,
            post_norm=cfg.MODEL.INTERNIMAGE.POST_NORM,
            with_cp=cfg.MODEL.INTERNIMAGE.WITH_CP,
            out_indices=cfg.MODEL.INTERNIMAGE.OUT_IINDICES,
            dw_kernel_size=cfg.MODEL.INTERNIMAGE.DW_KERNEL_SIZE,  # for InternImage-H/G
            res_post_norm=cfg.MODEL.INTERNIMAGE.RES_POST_NORM,  # for InternImage-H/G
            level2_post_norm=cfg.MODEL.INTERNIMAGE.LEVEL2_POST_NORM,  # for InternImage-H/G
            level2_post_norm_block_ids=cfg.MODEL.INTERNIMAGE.LEVEL2_POST_NORM_BLOCK_IDS,  # for InternImage-H/G
            center_feature_scale=cfg.MODEL.INTERNIMAGE.CENTER_FEATURE_SCALE,  # for InternImage-H/G
        )

        pretrained_weight = cfg.MODEL.INTERNIMAGE.PRETRAINED_WEIGHT
        if pretrained_weight:
            checkpoint = torch.load(pretrained_weight, map_location='cpu')
            print(f'\nload pretrain weight from {pretrained_weight} \n')
            self.load_state_dict(checkpoint['model'], strict=False)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
        Returns:
            dict[str->Tensor]: names and the corresponding features
        """
        assert (
            x.dim() == 4
        ), f"D2InternImage takes an input of shape (N, C, H, W). Got {x.shape} instead!"
        outputs = {}
        y = super().forward(x)
        for k in y.keys():
            if k in self._out_features:
                outputs[k] = y[k]
        return outputs

    def output_shape(self):
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
            )
            for name in self._out_features
        }

    @property
    def size_divisibility(self):
        return 32
GLEE/glee/backbone/registry.py
ADDED
@@ -0,0 +1,14 @@
_model_entrypoints = {}


def register_backbone(fn):
    module_name_split = fn.__module__.split('.')
    model_name = module_name_split[-1]
    _model_entrypoints[model_name] = fn
    return fn

def model_entrypoints(model_name):
    return _model_entrypoints[model_name]

def is_model(model_name):
    return model_name in _model_entrypoints
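A short usage sketch for this registry (illustration only; the import path is hypothetical and depends on how the package is installed). Builders are keyed by the file name of the module that defines them, so the decorated function in resnet.py below is looked up as "resnet":

# Hypothetical driver code, assuming the glee backbone package is importable.
from glee.backbone import resnet            # noqa: F401  importing runs @register_backbone
from glee.backbone.registry import is_model, model_entrypoints

if is_model("resnet"):
    build_fn = model_entrypoints("resnet")  # -> get_resnet_backbone from resnet.py
    # backbone = build_fn(cfg)              # cfg is the nested dict shown after resnet.py
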
GLEE/glee/backbone/resnet.py
ADDED
@@ -0,0 +1,731 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import pickle
import numpy as np
from typing import Any, Dict
import fvcore.nn.weight_init as weight_init
import torch
import torch.nn.functional as F
from torch import nn


from .backbone import Backbone
from .registry import register_backbone

from detectron2.layers import (
    CNNBlockBase,
    Conv2d,
    DeformConv,
    ModulatedDeformConv,
    ShapeSpec,
    get_norm,
)
from detectron2.utils.file_io import PathManager

__all__ = [
    "ResNetBlockBase",
    "BasicBlock",
    "BottleneckBlock",
    "DeformBottleneckBlock",
    "BasicStem",
    "ResNet",
    "make_stage",
    "get_resnet_backbone",
]


class BasicBlock(CNNBlockBase):
    """
    The basic residual block for ResNet-18 and ResNet-34 defined in :paper:`ResNet`,
    with two 3x3 conv layers and a projection shortcut if needed.
    """

    def __init__(self, in_channels, out_channels, *, stride=1, norm="BN"):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            stride (int): Stride for the first conv.
            norm (str or callable): normalization for all conv layers.
                See :func:`layers.get_norm` for supported format.
        """
        super().__init__(in_channels, out_channels, stride)

        if in_channels != out_channels:
            self.shortcut = Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False,
                norm=get_norm(norm, out_channels),
            )
        else:
            self.shortcut = None

        self.conv1 = Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
            norm=get_norm(norm, out_channels),
        )

        self.conv2 = Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
            norm=get_norm(norm, out_channels),
        )

        for layer in [self.conv1, self.conv2, self.shortcut]:
            if layer is not None:  # shortcut can be None
                weight_init.c2_msra_fill(layer)

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu_(out)
        out = self.conv2(out)

        if self.shortcut is not None:
            shortcut = self.shortcut(x)
        else:
            shortcut = x

        out += shortcut
        out = F.relu_(out)
        return out


class BottleneckBlock(CNNBlockBase):
    """
    The standard bottleneck residual block used by ResNet-50, 101 and 152
    defined in :paper:`ResNet`. It contains 3 conv layers with kernels
    1x1, 3x3, 1x1, and a projection shortcut if needed.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        *,
        bottleneck_channels,
        stride=1,
        num_groups=1,
        norm="BN",
        stride_in_1x1=False,
        dilation=1,
    ):
        """
        Args:
            bottleneck_channels (int): number of output channels for the 3x3
                "bottleneck" conv layers.
            num_groups (int): number of groups for the 3x3 conv layer.
            norm (str or callable): normalization for all conv layers.
                See :func:`layers.get_norm` for supported format.
            stride_in_1x1 (bool): when stride>1, whether to put stride in the
                first 1x1 convolution or the bottleneck 3x3 convolution.
            dilation (int): the dilation rate of the 3x3 conv layer.
        """
        super().__init__(in_channels, out_channels, stride)

        if in_channels != out_channels:
            self.shortcut = Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False,
                norm=get_norm(norm, out_channels),
            )
        else:
            self.shortcut = None

        # The original MSRA ResNet models have stride in the first 1x1 conv
        # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have
        # stride in the 3x3 conv
        stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)

        self.conv1 = Conv2d(
            in_channels,
            bottleneck_channels,
            kernel_size=1,
            stride=stride_1x1,
            bias=False,
            norm=get_norm(norm, bottleneck_channels),
        )

        self.conv2 = Conv2d(
            bottleneck_channels,
            bottleneck_channels,
            kernel_size=3,
            stride=stride_3x3,
            padding=1 * dilation,
            bias=False,
            groups=num_groups,
            dilation=dilation,
            norm=get_norm(norm, bottleneck_channels),
        )

        self.conv3 = Conv2d(
            bottleneck_channels,
            out_channels,
            kernel_size=1,
            bias=False,
            norm=get_norm(norm, out_channels),
        )

        for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
            if layer is not None:  # shortcut can be None
                weight_init.c2_msra_fill(layer)

        # Zero-initialize the last normalization in each residual branch,
        # so that at the beginning, the residual branch starts with zeros,
        # and each residual block behaves like an identity.
        # See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
        # "For BN layers, the learnable scaling coefficient γ is initialized
        # to be 1, except for each residual block's last BN
        # where γ is initialized to be 0."

        # nn.init.constant_(self.conv3.norm.weight, 0)
        # TODO this somehow hurts performance when training GN models from scratch.
        # Add it as an option when we need to use this code to train a backbone.

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu_(out)

        out = self.conv2(out)
        out = F.relu_(out)

        out = self.conv3(out)

        if self.shortcut is not None:
            shortcut = self.shortcut(x)
        else:
            shortcut = x

        out += shortcut
        out = F.relu_(out)
        return out

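A minimal shape-check sketch for the block above (illustration only, not part of the repository): a strided bottleneck block halves the spatial resolution and widens the channels, with the 1x1 projection shortcut absorbing the mismatch.

import torch

block = BottleneckBlock(64, 256, bottleneck_channels=64, stride=2, norm="BN").eval()
with torch.no_grad():
    y = block(torch.randn(1, 64, 56, 56))
print(tuple(y.shape))   # (1, 256, 28, 28): stride 2 in the 3x3 conv, 1x1 shortcut projects 64 -> 256
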
class DeformBottleneckBlock(CNNBlockBase):
    """
    Similar to :class:`BottleneckBlock`, but with :paper:`deformable conv <deformconv>`
    in the 3x3 convolution.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        *,
        bottleneck_channels,
        stride=1,
        num_groups=1,
        norm="BN",
        stride_in_1x1=False,
        dilation=1,
        deform_modulated=False,
        deform_num_groups=1,
    ):
        super().__init__(in_channels, out_channels, stride)
        self.deform_modulated = deform_modulated

        if in_channels != out_channels:
            self.shortcut = Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False,
                norm=get_norm(norm, out_channels),
            )
        else:
            self.shortcut = None

        stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)

        self.conv1 = Conv2d(
            in_channels,
            bottleneck_channels,
            kernel_size=1,
            stride=stride_1x1,
            bias=False,
            norm=get_norm(norm, bottleneck_channels),
        )

        if deform_modulated:
            deform_conv_op = ModulatedDeformConv
            # offset channels are 2 or 3 (if with modulated) * kernel_size * kernel_size
            offset_channels = 27
        else:
            deform_conv_op = DeformConv
            offset_channels = 18

        self.conv2_offset = Conv2d(
            bottleneck_channels,
            offset_channels * deform_num_groups,
            kernel_size=3,
            stride=stride_3x3,
            padding=1 * dilation,
            dilation=dilation,
        )
        self.conv2 = deform_conv_op(
            bottleneck_channels,
            bottleneck_channels,
            kernel_size=3,
            stride=stride_3x3,
            padding=1 * dilation,
            bias=False,
            groups=num_groups,
            dilation=dilation,
            deformable_groups=deform_num_groups,
            norm=get_norm(norm, bottleneck_channels),
        )

        self.conv3 = Conv2d(
            bottleneck_channels,
            out_channels,
            kernel_size=1,
            bias=False,
            norm=get_norm(norm, out_channels),
        )

        for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
            if layer is not None:  # shortcut can be None
                weight_init.c2_msra_fill(layer)

        nn.init.constant_(self.conv2_offset.weight, 0)
        nn.init.constant_(self.conv2_offset.bias, 0)

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu_(out)

        if self.deform_modulated:
            offset_mask = self.conv2_offset(out)
            offset_x, offset_y, mask = torch.chunk(offset_mask, 3, dim=1)
            offset = torch.cat((offset_x, offset_y), dim=1)
            mask = mask.sigmoid()
            out = self.conv2(out, offset, mask)
        else:
            offset = self.conv2_offset(out)
            out = self.conv2(out, offset)
        out = F.relu_(out)

        out = self.conv3(out)

        if self.shortcut is not None:
            shortcut = self.shortcut(x)
        else:
            shortcut = x

        out += shortcut
        out = F.relu_(out)
        return out


class BasicStem(CNNBlockBase):
    """
    The standard ResNet stem (layers before the first residual block),
    with a conv, relu and max_pool.
    """

    def __init__(self, in_channels=3, out_channels=64, norm="BN"):
        """
        Args:
            norm (str or callable): norm after the first conv layer.
                See :func:`layers.get_norm` for supported format.
        """
        super().__init__(in_channels, out_channels, 4)
        self.in_channels = in_channels
        self.conv1 = Conv2d(
            in_channels,
            out_channels,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
            norm=get_norm(norm, out_channels),
        )
        weight_init.c2_msra_fill(self.conv1)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu_(x)
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        return x


class ResNet(Backbone):
    """
    Implement :paper:`ResNet`.
    """

    def __init__(self, stem, stages, num_classes=None, out_features=None, freeze_at=0):
        """
        Args:
            stem (nn.Module): a stem module
            stages (list[list[CNNBlockBase]]): several (typically 4) stages,
                each contains multiple :class:`CNNBlockBase`.
            num_classes (None or int): if None, will not perform classification.
                Otherwise, will create a linear layer.
            out_features (list[str]): name of the layers whose outputs should
                be returned in forward. Can be anything in "stem", "linear", or "res2" ...
                If None, will return the output of the last layer.
            freeze_at (int): The number of stages at the beginning to freeze.
                see :meth:`freeze` for detailed explanation.
        """
        super().__init__()
        self.stem = stem
        self.num_classes = num_classes

        current_stride = self.stem.stride
        self._out_feature_strides = {"stem": current_stride}
        self._out_feature_channels = {"stem": self.stem.out_channels}

        self.stage_names, self.stages = [], []

        if out_features is not None:
            # Avoid keeping unused layers in this module. They consume extra memory
            # and may cause allreduce to fail
            num_stages = max(
                [{"res2": 1, "res3": 2, "res4": 3, "res5": 4}.get(f, 0) for f in out_features]
            )
            stages = stages[:num_stages]
        for i, blocks in enumerate(stages):
            assert len(blocks) > 0, len(blocks)
            for block in blocks:
                assert isinstance(block, CNNBlockBase), block

            name = "res" + str(i + 2)
            stage = nn.Sequential(*blocks)

            self.add_module(name, stage)
            self.stage_names.append(name)
            self.stages.append(stage)

            self._out_feature_strides[name] = current_stride = int(
                current_stride * np.prod([k.stride for k in blocks])
            )
            self._out_feature_channels[name] = curr_channels = blocks[-1].out_channels
        self.stage_names = tuple(self.stage_names)  # Make it static for scripting

        if num_classes is not None:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.linear = nn.Linear(curr_channels, num_classes)

            # Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
            # "The 1000-way fully-connected layer is initialized by
            # drawing weights from a zero-mean Gaussian with standard deviation of 0.01."
            nn.init.normal_(self.linear.weight, std=0.01)
            name = "linear"

        if out_features is None:
            out_features = [name]
        self._out_features = out_features
        assert len(self._out_features)
        children = [x[0] for x in self.named_children()]
        for out_feature in self._out_features:
            assert out_feature in children, "Available children: {}".format(", ".join(children))
        self.freeze(freeze_at)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.

        Returns:
            dict[str->Tensor]: names and the corresponding features
        """
        assert x.dim() == 4, f"ResNet takes an input of shape (N, C, H, W). Got {x.shape} instead!"
        outputs = {}
        x = self.stem(x)
        if "stem" in self._out_features:
            outputs["stem"] = x
        for name, stage in zip(self.stage_names, self.stages):
            x = stage(x)
            if name in self._out_features:
                outputs[name] = x
        if self.num_classes is not None:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.linear(x)
            if "linear" in self._out_features:
                outputs["linear"] = x
        return outputs

    def output_shape(self):
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
            )
            for name in self._out_features
        }

    def freeze(self, freeze_at=0):
        """
        Freeze the first several stages of the ResNet. Commonly used in
        fine-tuning.

        Layers that produce the same feature map spatial size are defined as one
        "stage" by :paper:`FPN`.

        Args:
            freeze_at (int): number of stages to freeze.
                `1` means freezing the stem. `2` means freezing the stem and
                one residual stage, etc.

        Returns:
            nn.Module: this ResNet itself
        """
        if freeze_at >= 1:
            self.stem.freeze()
        for idx, stage in enumerate(self.stages, start=2):
            if freeze_at >= idx:
                for block in stage.children():
                    block.freeze()
        return self

    @staticmethod
    def make_stage(block_class, num_blocks, *, in_channels, out_channels, **kwargs):
        """
        Create a list of blocks of the same type that forms one ResNet stage.

        Args:
            block_class (type): a subclass of CNNBlockBase that's used to create all blocks in this
                stage. A module of this type must not change spatial resolution of inputs unless its
                stride != 1.
            num_blocks (int): number of blocks in this stage
            in_channels (int): input channels of the entire stage.
            out_channels (int): output channels of **every block** in the stage.
            kwargs: other arguments passed to the constructor of
                `block_class`. If the argument name is "xx_per_block", the
                argument is a list of values to be passed to each block in the
                stage. Otherwise, the same argument is passed to every block
                in the stage.

        Returns:
            list[CNNBlockBase]: a list of block module.

        Examples:
        ::
            stage = ResNet.make_stage(
                BottleneckBlock, 3, in_channels=16, out_channels=64,
                bottleneck_channels=16, num_groups=1,
                stride_per_block=[2, 1, 1],
                dilations_per_block=[1, 1, 2]
            )

        Usually, layers that produce the same feature map spatial size are defined as one
        "stage" (in :paper:`FPN`). Under such definition, ``stride_per_block[1:]`` should
        all be 1.
        """
        blocks = []
        for i in range(num_blocks):
            curr_kwargs = {}
            for k, v in kwargs.items():
                if k.endswith("_per_block"):
                    assert len(v) == num_blocks, (
                        f"Argument '{k}' of make_stage should have the "
                        f"same length as num_blocks={num_blocks}."
                    )
                    newk = k[: -len("_per_block")]
                    assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!"
                    curr_kwargs[newk] = v[i]
                else:
                    curr_kwargs[k] = v

            blocks.append(
                block_class(in_channels=in_channels, out_channels=out_channels, **curr_kwargs)
            )
            in_channels = out_channels
        return blocks

    @staticmethod
    def make_default_stages(depth, block_class=None, **kwargs):
        """
        Create a list of ResNet stages from pre-defined depth (one of 18, 34, 50, 101, 152).
        If it doesn't create the ResNet variant you need, please use :meth:`make_stage`
        instead for fine-grained customization.

        Args:
            depth (int): depth of ResNet
            block_class (type): the CNN block class. Has to accept
                `bottleneck_channels` argument for depth > 50.
                By default it is BasicBlock or BottleneckBlock, based on the
                depth.
            kwargs:
                other arguments to pass to `make_stage`. Should not contain
                stride and channels, as they are predefined for each depth.

        Returns:
            list[list[CNNBlockBase]]: modules in all stages; see arguments of
                :class:`ResNet.__init__`.
        """
        num_blocks_per_stage = {
            18: [2, 2, 2, 2],
            34: [3, 4, 6, 3],
            50: [3, 4, 6, 3],
            101: [3, 4, 23, 3],
            152: [3, 8, 36, 3],
        }[depth]
        if block_class is None:
            block_class = BasicBlock if depth < 50 else BottleneckBlock
        if depth < 50:
            in_channels = [64, 64, 128, 256]
            out_channels = [64, 128, 256, 512]
        else:
            in_channels = [64, 256, 512, 1024]
            out_channels = [256, 512, 1024, 2048]
        ret = []
        for (n, s, i, o) in zip(num_blocks_per_stage, [1, 2, 2, 2], in_channels, out_channels):
            if depth >= 50:
                kwargs["bottleneck_channels"] = o // 4
            ret.append(
                ResNet.make_stage(
                    block_class=block_class,
                    num_blocks=n,
                    stride_per_block=[s] + [1] * (n - 1),
                    in_channels=i,
                    out_channels=o,
                    **kwargs,
                )
            )
        return ret


ResNetBlockBase = CNNBlockBase
"""
Alias for backward compatibility.
"""


def make_stage(*args, **kwargs):
    """
    Deprecated alias for backward compatibility.
    """
    return ResNet.make_stage(*args, **kwargs)


def _convert_ndarray_to_tensor(state_dict: Dict[str, Any]) -> None:
    """
    In-place convert all numpy arrays in the state_dict to torch tensor.
    Args:
        state_dict (dict): a state-dict to be loaded to the model.
            Will be modified.
    """
    # model could be an OrderedDict with _metadata attribute
    # (as returned by Pytorch's state_dict()). We should preserve these
    # properties.
    for k in list(state_dict.keys()):
        v = state_dict[k]
        if not isinstance(v, np.ndarray) and not isinstance(v, torch.Tensor):
            raise ValueError(
                "Unsupported type found in checkpoint! {}: {}".format(k, type(v))
            )
        if not isinstance(v, torch.Tensor):
            state_dict[k] = torch.from_numpy(v)


@register_backbone
def get_resnet_backbone(cfg):
    """
    Create a ResNet instance from config.

    Returns:
        ResNet: a :class:`ResNet` instance.
    """
    res_cfg = cfg['MODEL']['BACKBONE']['RESNETS']

    # need registration of new blocks/stems?
    norm = res_cfg['NORM']
    stem = BasicStem(
        in_channels=res_cfg['STEM_IN_CHANNELS'],
        out_channels=res_cfg['STEM_OUT_CHANNELS'],
        norm=norm,
    )

    # fmt: off
    freeze_at = res_cfg['FREEZE_AT']
    out_features = res_cfg['OUT_FEATURES']
    depth = res_cfg['DEPTH']
    num_groups = res_cfg['NUM_GROUPS']
    width_per_group = res_cfg['WIDTH_PER_GROUP']
    bottleneck_channels = num_groups * width_per_group
    in_channels = res_cfg['STEM_OUT_CHANNELS']
    out_channels = res_cfg['RES2_OUT_CHANNELS']
    stride_in_1x1 = res_cfg['STRIDE_IN_1X1']
    res5_dilation = res_cfg['RES5_DILATION']
    deform_on_per_stage = res_cfg['DEFORM_ON_PER_STAGE']
    deform_modulated = res_cfg['DEFORM_MODULATED']
    deform_num_groups = res_cfg['DEFORM_NUM_GROUPS']
    # fmt: on
    assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation)

    num_blocks_per_stage = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
    }[depth]

    if depth in [18, 34]:
        assert out_channels == 64, "Must set MODEL.RESNETS.RES2_OUT_CHANNELS = 64 for R18/R34"
        assert not any(
            deform_on_per_stage
        ), "MODEL.RESNETS.DEFORM_ON_PER_STAGE unsupported for R18/R34"
        assert res5_dilation == 1, "Must set MODEL.RESNETS.RES5_DILATION = 1 for R18/R34"
        assert num_groups == 1, "Must set MODEL.RESNETS.NUM_GROUPS = 1 for R18/R34"

    stages = []

    for idx, stage_idx in enumerate(range(2, 6)):
        # res5_dilation is used this way as a convention in R-FCN & Deformable Conv paper
        dilation = res5_dilation if stage_idx == 5 else 1
        first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2
        stage_kargs = {
            "num_blocks": num_blocks_per_stage[idx],
            "stride_per_block": [first_stride] + [1] * (num_blocks_per_stage[idx] - 1),
            "in_channels": in_channels,
            "out_channels": out_channels,
            "norm": norm,
        }
        # Use BasicBlock for R18 and R34.
        if depth in [18, 34]:
            stage_kargs["block_class"] = BasicBlock
        else:
            stage_kargs["bottleneck_channels"] = bottleneck_channels
            stage_kargs["stride_in_1x1"] = stride_in_1x1
            stage_kargs["dilation"] = dilation
            stage_kargs["num_groups"] = num_groups
            if deform_on_per_stage[idx]:
                stage_kargs["block_class"] = DeformBottleneckBlock
                stage_kargs["deform_modulated"] = deform_modulated
                stage_kargs["deform_num_groups"] = deform_num_groups
            else:
                stage_kargs["block_class"] = BottleneckBlock
        blocks = ResNet.make_stage(**stage_kargs)
        in_channels = out_channels
        out_channels *= 2
        bottleneck_channels *= 2
        stages.append(blocks)
    backbone = ResNet(stem, stages, out_features=out_features, freeze_at=freeze_at)

    if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True:
        filename = cfg['MODEL']['BACKBONE']['PRETRAINED']
        with PathManager.open(filename, "rb") as f:
            ckpt = pickle.load(f, encoding="latin1")['model']
        _convert_ndarray_to_tensor(ckpt)
        ckpt.pop('stem.fc.weight')
        ckpt.pop('stem.fc.bias')
        backbone.load_state_dict(ckpt)

    return backbone
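A configuration sketch for the builder above (illustration only, not part of the repository; the values simply mirror a plain R50 without pretrained weights). Note that get_resnet_backbone reads a nested dict rather than a detectron2 CfgNode:

cfg = {"MODEL": {"BACKBONE": {
    "LOAD_PRETRAINED": False,
    "PRETRAINED": "",
    "RESNETS": {
        "NORM": "FrozenBN", "STEM_IN_CHANNELS": 3, "STEM_OUT_CHANNELS": 64,
        "FREEZE_AT": 0, "OUT_FEATURES": ["res2", "res3", "res4", "res5"],
        "DEPTH": 50, "NUM_GROUPS": 1, "WIDTH_PER_GROUP": 64,
        "RES2_OUT_CHANNELS": 256, "STRIDE_IN_1X1": False, "RES5_DILATION": 1,
        "DEFORM_ON_PER_STAGE": [False, False, False, False],
        "DEFORM_MODULATED": False, "DEFORM_NUM_GROUPS": 1,
    },
}}}
backbone = get_resnet_backbone(cfg)
print({k: (v.channels, v.stride) for k, v in backbone.output_shape().items()})
# {'res2': (256, 4), 'res3': (512, 8), 'res4': (1024, 16), 'res5': (2048, 32)}
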
GLEE/glee/backbone/swin.py
ADDED
@@ -0,0 +1,783 @@
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu, Yutong Lin, Yixuan Wei
# --------------------------------------------------------


import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
from .registry import register_backbone


class Mlp(nn.Module):
    """Multilayer perceptron."""

    def __init__(
        self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

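A minimal round-trip sketch for the two helpers above (illustration only, not part of the repository): when H and W are multiples of window_size, window_partition and window_reverse are exact inverses, since both are pure reshape/permute operations.

import torch

B, H, W, C, ws = 2, 8, 8, 3, 4
x = torch.randn(B, H, W, C)
windows = window_partition(x, ws)                      # (B * H/ws * W/ws, ws, ws, C)
assert windows.shape == (B * (H // ws) * (W // ws), ws, ws, C)
assert torch.equal(window_reverse(windows, ws, H, W), x)
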
class WindowAttention(nn.Module):
    """Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """Forward function.
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
        )  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1
        ).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SwinTransformerBlock(nn.Module):
    """Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(
        self,
        dim,
        num_heads,
        window_size=7,
        shift_size=0,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
        )

        self.H = None
        self.W = None

    def forward(self, x, mask_matrix):
        """Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
            mask_matrix: Attention mask for cyclic shift.
        """
        B, L, C = x.shape
        H, W = self.H, self.W
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(
            shifted_x, self.window_size
        )  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(
            -1, self.window_size * self.window_size, C
        )  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

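A small sketch of the cyclic shift used in the block above (illustration only, not part of the repository): rolling by -shift_size before windowing and by +shift_size afterwards is lossless, which is why the shifted-window path can be undone exactly.

import torch

x = torch.arange(16.).view(1, 4, 4, 1)                 # B, H, W, C
s = 2
shifted = torch.roll(x, shifts=(-s, -s), dims=(1, 2))  # as in the SW-MSA branch
restored = torch.roll(shifted, shifts=(s, s), dims=(1, 2))
assert torch.equal(restored, x)
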
297 |
+
class PatchMerging(nn.Module):
|
298 |
+
"""Patch Merging Layer
|
299 |
+
Args:
|
300 |
+
dim (int): Number of input channels.
|
301 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
302 |
+
"""
|
303 |
+
|
304 |
+
def __init__(self, dim, norm_layer=nn.LayerNorm):
|
305 |
+
super().__init__()
|
306 |
+
self.dim = dim
|
307 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
308 |
+
self.norm = norm_layer(4 * dim)
|
309 |
+
|
310 |
+
def forward(self, x, H, W):
|
311 |
+
"""Forward function.
|
312 |
+
Args:
|
313 |
+
x: Input feature, tensor size (B, H*W, C).
|
314 |
+
H, W: Spatial resolution of the input feature.
|
315 |
+
"""
|
316 |
+
B, L, C = x.shape
|
317 |
+
assert L == H * W, "input feature has wrong size"
|
318 |
+
|
319 |
+
x = x.view(B, H, W, C)
|
320 |
+
|
321 |
+
# padding
|
322 |
+
pad_input = (H % 2 == 1) or (W % 2 == 1)
|
323 |
+
if pad_input:
|
324 |
+
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
|
325 |
+
|
326 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
327 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
328 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
329 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
330 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
331 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
332 |
+
|
333 |
+
x = self.norm(x)
|
334 |
+
x = self.reduction(x)
|
335 |
+
|
336 |
+
return x
|
337 |
+
|
338 |
+
|
339 |
+
class BasicLayer(nn.Module):
|
340 |
+
"""A basic Swin Transformer layer for one stage.
|
341 |
+
Args:
|
342 |
+
dim (int): Number of feature channels
|
343 |
+
depth (int): Depths of this stage.
|
344 |
+
num_heads (int): Number of attention head.
|
345 |
+
window_size (int): Local window size. Default: 7.
|
346 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
347 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
348 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
349 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
350 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
351 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
352 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
353 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
354 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
355 |
+
"""
|
356 |
+
|
357 |
+
def __init__(
|
358 |
+
self,
|
359 |
+
dim,
|
360 |
+
depth,
|
361 |
+
num_heads,
|
362 |
+
window_size=7,
|
363 |
+
mlp_ratio=4.0,
|
364 |
+
qkv_bias=True,
|
365 |
+
qk_scale=None,
|
366 |
+
drop=0.0,
|
367 |
+
attn_drop=0.0,
|
368 |
+
drop_path=0.0,
|
369 |
+
norm_layer=nn.LayerNorm,
|
370 |
+
downsample=None,
|
371 |
+
use_checkpoint=False,
|
372 |
+
):
|
373 |
+
super().__init__()
|
374 |
+
self.window_size = window_size
|
375 |
+
self.shift_size = window_size // 2
|
376 |
+
self.depth = depth
|
377 |
+
self.use_checkpoint = use_checkpoint
|
378 |
+
|
379 |
+
# build blocks
|
380 |
+
self.blocks = nn.ModuleList(
|
381 |
+
[
|
382 |
+
SwinTransformerBlock(
|
383 |
+
dim=dim,
|
384 |
+
num_heads=num_heads,
|
385 |
+
window_size=window_size,
|
386 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
387 |
+
mlp_ratio=mlp_ratio,
|
388 |
+
qkv_bias=qkv_bias,
|
389 |
+
qk_scale=qk_scale,
|
390 |
+
drop=drop,
|
391 |
+
attn_drop=attn_drop,
|
392 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
393 |
+
norm_layer=norm_layer,
|
394 |
+
)
|
395 |
+
for i in range(depth)
|
396 |
+
]
|
397 |
+
)
|
398 |
+
|
399 |
+
# patch merging layer
|
400 |
+
if downsample is not None:
|
401 |
+
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
|
402 |
+
else:
|
403 |
+
self.downsample = None
|
404 |
+
|
405 |
+
def forward(self, x, H, W):
|
406 |
+
"""Forward function.
|
407 |
+
Args:
|
408 |
+
x: Input feature, tensor size (B, H*W, C).
|
409 |
+
H, W: Spatial resolution of the input feature.
|
410 |
+
"""
|
411 |
+
|
412 |
+
# calculate attention mask for SW-MSA
|
413 |
+
Hp = int(np.ceil(H / self.window_size)) * self.window_size
|
414 |
+
Wp = int(np.ceil(W / self.window_size)) * self.window_size
|
415 |
+
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
|
416 |
+
h_slices = (
|
417 |
+
slice(0, -self.window_size),
|
418 |
+
slice(-self.window_size, -self.shift_size),
|
419 |
+
slice(-self.shift_size, None),
|
420 |
+
)
|
421 |
+
w_slices = (
|
422 |
+
slice(0, -self.window_size),
|
423 |
+
slice(-self.window_size, -self.shift_size),
|
424 |
+
slice(-self.shift_size, None),
|
425 |
+
)
|
426 |
+
cnt = 0
|
427 |
+
for h in h_slices:
|
428 |
+
for w in w_slices:
|
429 |
+
img_mask[:, h, w, :] = cnt
|
430 |
+
cnt += 1
|
431 |
+
|
432 |
+
mask_windows = window_partition(
|
433 |
+
img_mask, self.window_size
|
434 |
+
) # nW, window_size, window_size, 1
|
435 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
436 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
437 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
|
438 |
+
attn_mask == 0, float(0.0)
|
439 |
+
)
|
440 |
+
|
441 |
+
for blk in self.blocks:
|
442 |
+
blk.H, blk.W = H, W
|
443 |
+
if self.use_checkpoint:
|
444 |
+
x = checkpoint.checkpoint(blk, x, attn_mask)
|
445 |
+
else:
|
446 |
+
x = blk(x, attn_mask)
|
447 |
+
if self.downsample is not None:
|
448 |
+
x_down = self.downsample(x, H, W)
|
449 |
+
Wh, Ww = (H + 1) // 2, (W + 1) // 2
|
450 |
+
return x, H, W, x_down, Wh, Ww
|
451 |
+
else:
|
452 |
+
return x, H, W, x, H, W
|
453 |
+
|
454 |
+
|
455 |
+
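For reference, the attention mask assembled in BasicLayer.forward above can be reproduced standalone: windows that mix pixels from different regions of the cyclically shifted image get a -100 bias so that cross-region pairs are suppressed after the softmax. A self-contained toy example with an 8x8 map, window_size=4 and shift_size=2:

import torch

window_size, shift_size = 4, 2
Hp = Wp = 8
img_mask = torch.zeros(1, Hp, Wp, 1)
slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in slices:
    for w in slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

# same reshape as window_partition: (num_windows, window_size * window_size)
mask_windows = (
    img_mask.view(1, Hp // window_size, window_size, Wp // window_size, window_size, 1)
    .permute(0, 1, 3, 2, 4, 5)
    .reshape(-1, window_size * window_size)
)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
print(attn_mask.shape)             # torch.Size([4, 16, 16])
print((attn_mask == -100).any())   # tensor(True): only windows that straddle shifted regions are masked
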
class PatchEmbed(nn.Module):
|
456 |
+
"""Image to Patch Embedding
|
457 |
+
Args:
|
458 |
+
patch_size (int): Patch token size. Default: 4.
|
459 |
+
in_chans (int): Number of input image channels. Default: 3.
|
460 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
461 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
462 |
+
"""
|
463 |
+
|
464 |
+
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
465 |
+
super().__init__()
|
466 |
+
patch_size = to_2tuple(patch_size)
|
467 |
+
self.patch_size = patch_size
|
468 |
+
|
469 |
+
self.in_chans = in_chans
|
470 |
+
self.embed_dim = embed_dim
|
471 |
+
|
472 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
473 |
+
if norm_layer is not None:
|
474 |
+
self.norm = norm_layer(embed_dim)
|
475 |
+
else:
|
476 |
+
self.norm = None
|
477 |
+
|
478 |
+
def forward(self, x):
|
479 |
+
"""Forward function."""
|
480 |
+
# padding
|
481 |
+
_, _, H, W = x.size()
|
482 |
+
if W % self.patch_size[1] != 0:
|
483 |
+
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
484 |
+
if H % self.patch_size[0] != 0:
|
485 |
+
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
486 |
+
|
487 |
+
x = self.proj(x) # B C Wh Ww
|
488 |
+
if self.norm is not None:
|
489 |
+
Wh, Ww = x.size(2), x.size(3)
|
490 |
+
x = x.flatten(2).transpose(1, 2)
|
491 |
+
x = self.norm(x)
|
492 |
+
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
|
493 |
+
|
494 |
+
return x
|
495 |
+
|
496 |
+
|
497 |
+
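A quick sanity check of the patch-embedding stem above (the import path is an assumption): inputs whose height or width is not a multiple of patch_size are padded on the right/bottom before the strided convolution, so the token grid covers ceil(H/4) by ceil(W/4).

import torch
from glee.backbone.swin import PatchEmbed  # assumed import path

stem = PatchEmbed(patch_size=4, in_chans=3, embed_dim=96)
x = torch.randn(1, 3, 230, 317)    # neither dimension is divisible by 4
print(stem(x).shape)               # torch.Size([1, 96, 58, 80])
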
class SwinTransformer(nn.Module):
|
498 |
+
"""Swin Transformer backbone.
|
499 |
+
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
500 |
+
https://arxiv.org/pdf/2103.14030
|
501 |
+
Args:
|
502 |
+
pretrain_img_size (int): Input image size for training the pretrained model,
|
503 |
+
used in absolute position embedding. Default 224.
|
504 |
+
patch_size (int | tuple(int)): Patch size. Default: 4.
|
505 |
+
in_chans (int): Number of input image channels. Default: 3.
|
506 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
507 |
+
depths (tuple[int]): Depths of each Swin Transformer stage.
|
508 |
+
num_heads (tuple[int]): Number of attention head of each stage.
|
509 |
+
window_size (int): Window size. Default: 7.
|
510 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
511 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
512 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
|
513 |
+
drop_rate (float): Dropout rate.
|
514 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0.
|
515 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
|
516 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
517 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
|
518 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
|
519 |
+
out_indices (Sequence[int]): Output from which stages.
|
520 |
+
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
521 |
+
-1 means not freezing any parameters.
|
522 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
523 |
+
"""
|
524 |
+
|
525 |
+
def __init__(
|
526 |
+
self,
|
527 |
+
pretrain_img_size=224,
|
528 |
+
patch_size=4,
|
529 |
+
in_chans=3,
|
530 |
+
embed_dim=96,
|
531 |
+
depths=[2, 2, 6, 2],
|
532 |
+
num_heads=[3, 6, 12, 24],
|
533 |
+
window_size=7,
|
534 |
+
mlp_ratio=4.0,
|
535 |
+
qkv_bias=True,
|
536 |
+
qk_scale=None,
|
537 |
+
drop_rate=0.0,
|
538 |
+
attn_drop_rate=0.0,
|
539 |
+
drop_path_rate=0.2,
|
540 |
+
norm_layer=nn.LayerNorm,
|
541 |
+
ape=False,
|
542 |
+
patch_norm=True,
|
543 |
+
out_indices=(0, 1, 2, 3),
|
544 |
+
frozen_stages=-1,
|
545 |
+
use_checkpoint=False,
|
546 |
+
):
|
547 |
+
super().__init__()
|
548 |
+
|
549 |
+
self.pretrain_img_size = pretrain_img_size
|
550 |
+
self.num_layers = len(depths)
|
551 |
+
self.embed_dim = embed_dim
|
552 |
+
self.ape = ape
|
553 |
+
self.patch_norm = patch_norm
|
554 |
+
self.out_indices = out_indices
|
555 |
+
self.frozen_stages = frozen_stages
|
556 |
+
|
557 |
+
# split image into non-overlapping patches
|
558 |
+
self.patch_embed = PatchEmbed(
|
559 |
+
patch_size=patch_size,
|
560 |
+
in_chans=in_chans,
|
561 |
+
embed_dim=embed_dim,
|
562 |
+
norm_layer=norm_layer if self.patch_norm else None,
|
563 |
+
)
|
564 |
+
|
565 |
+
# absolute position embedding
|
566 |
+
if self.ape:
|
567 |
+
pretrain_img_size = to_2tuple(pretrain_img_size)
|
568 |
+
patch_size = to_2tuple(patch_size)
|
569 |
+
patches_resolution = [
|
570 |
+
pretrain_img_size[0] // patch_size[0],
|
571 |
+
pretrain_img_size[1] // patch_size[1],
|
572 |
+
]
|
573 |
+
|
574 |
+
self.absolute_pos_embed = nn.Parameter(
|
575 |
+
torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
|
576 |
+
)
|
577 |
+
trunc_normal_(self.absolute_pos_embed, std=0.02)
|
578 |
+
|
579 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
580 |
+
|
581 |
+
# stochastic depth
|
582 |
+
dpr = [
|
583 |
+
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
|
584 |
+
] # stochastic depth decay rule
|
585 |
+
|
586 |
+
# build layers
|
587 |
+
self.layers = nn.ModuleList()
|
588 |
+
for i_layer in range(self.num_layers):
|
589 |
+
layer = BasicLayer(
|
590 |
+
dim=int(embed_dim * 2 ** i_layer),
|
591 |
+
depth=depths[i_layer],
|
592 |
+
num_heads=num_heads[i_layer],
|
593 |
+
window_size=window_size,
|
594 |
+
mlp_ratio=mlp_ratio,
|
595 |
+
qkv_bias=qkv_bias,
|
596 |
+
qk_scale=qk_scale,
|
597 |
+
drop=drop_rate,
|
598 |
+
attn_drop=attn_drop_rate,
|
599 |
+
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
|
600 |
+
norm_layer=norm_layer,
|
601 |
+
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
602 |
+
use_checkpoint=use_checkpoint,
|
603 |
+
)
|
604 |
+
self.layers.append(layer)
|
605 |
+
|
606 |
+
num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
|
607 |
+
self.num_features = num_features
|
608 |
+
|
609 |
+
# add a norm layer for each output
|
610 |
+
for i_layer in out_indices:
|
611 |
+
layer = norm_layer(num_features[i_layer])
|
612 |
+
layer_name = f"norm{i_layer}"
|
613 |
+
self.add_module(layer_name, layer)
|
614 |
+
|
615 |
+
self._freeze_stages()
|
616 |
+
|
617 |
+
def _freeze_stages(self):
|
618 |
+
if self.frozen_stages >= 0:
|
619 |
+
self.patch_embed.eval()
|
620 |
+
for param in self.patch_embed.parameters():
|
621 |
+
param.requires_grad = False
|
622 |
+
|
623 |
+
if self.frozen_stages >= 1 and self.ape:
|
624 |
+
self.absolute_pos_embed.requires_grad = False
|
625 |
+
|
626 |
+
if self.frozen_stages >= 2:
|
627 |
+
self.pos_drop.eval()
|
628 |
+
for i in range(0, self.frozen_stages - 1):
|
629 |
+
m = self.layers[i]
|
630 |
+
m.eval()
|
631 |
+
for param in m.parameters():
|
632 |
+
param.requires_grad = False
|
633 |
+
|
634 |
+
def init_weights(self, pretrained=None):
|
635 |
+
"""Initialize the weights in backbone.
|
636 |
+
Args:
|
637 |
+
pretrained (str, optional): Path to pre-trained weights.
|
638 |
+
Defaults to None.
|
639 |
+
"""
|
640 |
+
|
641 |
+
def _init_weights(m):
|
642 |
+
if isinstance(m, nn.Linear):
|
643 |
+
trunc_normal_(m.weight, std=0.02)
|
644 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
645 |
+
nn.init.constant_(m.bias, 0)
|
646 |
+
elif isinstance(m, nn.LayerNorm):
|
647 |
+
nn.init.constant_(m.bias, 0)
|
648 |
+
nn.init.constant_(m.weight, 1.0)
|
649 |
+
|
650 |
+
if isinstance(pretrained, str):
|
651 |
+
self.apply(_init_weights)
|
652 |
+
checkpoint = torch.load(pretrained, map_location='cpu')
|
653 |
+
print(f'\nload pretrain weight from {pretrained} \n')
|
654 |
+
self.load_state_dict(checkpoint['model'], strict=False)
|
655 |
+
elif pretrained is None:
|
656 |
+
self.apply(_init_weights)
|
657 |
+
else:
|
658 |
+
raise TypeError('pretrained must be a str or None')
|
659 |
+
|
660 |
+
def forward(self, x):
|
661 |
+
"""Forward function."""
|
662 |
+
x = self.patch_embed(x)
|
663 |
+
|
664 |
+
Wh, Ww = x.size(2), x.size(3)
|
665 |
+
if self.ape:
|
666 |
+
# interpolate the position embedding to the corresponding size
|
667 |
+
absolute_pos_embed = F.interpolate(
|
668 |
+
self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
|
669 |
+
)
|
670 |
+
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
|
671 |
+
else:
|
672 |
+
x = x.flatten(2).transpose(1, 2)
|
673 |
+
x = self.pos_drop(x)
|
674 |
+
|
675 |
+
outs = {}
|
676 |
+
for i in range(self.num_layers):
|
677 |
+
layer = self.layers[i]
|
678 |
+
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
679 |
+
|
680 |
+
if i in self.out_indices:
|
681 |
+
norm_layer = getattr(self, f"norm{i}")
|
682 |
+
x_out = norm_layer(x_out)
|
683 |
+
|
684 |
+
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
|
685 |
+
outs["res{}".format(i + 2)] = out
|
686 |
+
|
687 |
+
return outs
|
688 |
+
|
689 |
+
def train(self, mode=True):
|
690 |
+
"""Convert the model into training mode while keep layers freezed."""
|
691 |
+
super(SwinTransformer, self).train(mode)
|
692 |
+
self._freeze_stages()
|
693 |
+
|
694 |
+
|
695 |
+
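Before the detectron2 wrapper below, a short sketch of the per-stage bookkeeping done in SwinTransformer.__init__ above for the default Swin-T-style settings: the channel width doubles per stage, and the stochastic-depth rate grows linearly over all blocks and is sliced per stage.

import torch

embed_dim, depths, drop_path_rate = 96, [2, 2, 6, 2], 0.2
num_features = [int(embed_dim * 2 ** i) for i in range(len(depths))]
print(num_features)                 # [96, 192, 384, 768] -> channels of res2..res5
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
per_stage = [dpr[sum(depths[:i]):sum(depths[:i + 1])] for i in range(len(depths))]
print([len(p) for p in per_stage])  # [2, 2, 6, 2] drop-path rates, one per block in each stage
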
@BACKBONE_REGISTRY.register()
|
696 |
+
class D2SwinTransformer(SwinTransformer, Backbone):
|
697 |
+
def __init__(self, cfg, input_shape):
|
698 |
+
|
699 |
+
pretrain_img_size = cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE
|
700 |
+
patch_size = cfg.MODEL.SWIN.PATCH_SIZE
|
701 |
+
in_chans = 3
|
702 |
+
embed_dim = cfg.MODEL.SWIN.EMBED_DIM
|
703 |
+
depths = cfg.MODEL.SWIN.DEPTHS
|
704 |
+
num_heads = cfg.MODEL.SWIN.NUM_HEADS
|
705 |
+
window_size = cfg.MODEL.SWIN.WINDOW_SIZE
|
706 |
+
mlp_ratio = cfg.MODEL.SWIN.MLP_RATIO
|
707 |
+
qkv_bias = cfg.MODEL.SWIN.QKV_BIAS
|
708 |
+
qk_scale = cfg.MODEL.SWIN.QK_SCALE
|
709 |
+
drop_rate = cfg.MODEL.SWIN.DROP_RATE
|
710 |
+
attn_drop_rate = cfg.MODEL.SWIN.ATTN_DROP_RATE
|
711 |
+
drop_path_rate = cfg.MODEL.SWIN.DROP_PATH_RATE
|
712 |
+
norm_layer = nn.LayerNorm
|
713 |
+
ape = cfg.MODEL.SWIN.APE
|
714 |
+
patch_norm = cfg.MODEL.SWIN.PATCH_NORM
|
715 |
+
use_checkpoint = cfg.MODEL.SWIN.USE_CHECKPOINT
|
716 |
+
pretrained_weight = cfg.MODEL.SWIN.PRETRAINED_WEIGHT
|
717 |
+
|
718 |
+
|
719 |
+
super().__init__(
|
720 |
+
pretrain_img_size,
|
721 |
+
patch_size,
|
722 |
+
in_chans,
|
723 |
+
embed_dim,
|
724 |
+
depths,
|
725 |
+
num_heads,
|
726 |
+
window_size,
|
727 |
+
mlp_ratio,
|
728 |
+
qkv_bias,
|
729 |
+
qk_scale,
|
730 |
+
drop_rate,
|
731 |
+
attn_drop_rate,
|
732 |
+
drop_path_rate,
|
733 |
+
norm_layer,
|
734 |
+
ape,
|
735 |
+
patch_norm,
|
736 |
+
use_checkpoint=use_checkpoint,
|
737 |
+
)
|
738 |
+
self.init_weights(pretrained_weight)
|
739 |
+
self._out_features = cfg.MODEL.SWIN.OUT_FEATURES
|
740 |
+
|
741 |
+
self._out_feature_strides = {
|
742 |
+
"res2": 4,
|
743 |
+
"res3": 8,
|
744 |
+
"res4": 16,
|
745 |
+
"res5": 32,
|
746 |
+
}
|
747 |
+
self._out_feature_channels = {
|
748 |
+
"res2": self.num_features[0],
|
749 |
+
"res3": self.num_features[1],
|
750 |
+
"res4": self.num_features[2],
|
751 |
+
"res5": self.num_features[3],
|
752 |
+
}
|
753 |
+
|
754 |
+
def forward(self, x):
|
755 |
+
"""
|
756 |
+
Args:
|
757 |
+
x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
|
758 |
+
Returns:
|
759 |
+
dict[str->Tensor]: names and the corresponding features
|
760 |
+
"""
|
761 |
+
assert (
|
762 |
+
x.dim() == 4
|
763 |
+
), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
|
764 |
+
outputs = {}
|
765 |
+
y = super().forward(x)
|
766 |
+
for k in y.keys():
|
767 |
+
if k in self._out_features:
|
768 |
+
outputs[k] = y[k]
|
769 |
+
return outputs
|
770 |
+
|
771 |
+
def output_shape(self):
|
772 |
+
return {
|
773 |
+
name: ShapeSpec(
|
774 |
+
channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
|
775 |
+
)
|
776 |
+
for name in self._out_features
|
777 |
+
}
|
778 |
+
|
779 |
+
@property
|
780 |
+
def size_divisibility(self):
|
781 |
+
return 32
|
782 |
+
|
783 |
+
|
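A minimal usage sketch for the registered D2SwinTransformer above; the import paths, and the assumption that no pretrained weight is configured (MODEL.SWIN.PRETRAINED_WEIGHT left at the None default added in GLEE/glee/config.py), are mine:

import torch
from detectron2.config import get_cfg
from detectron2.layers import ShapeSpec
from glee.config import add_glee_config            # assumed import path
from glee.backbone.swin import D2SwinTransformer   # assumed import path

cfg = get_cfg()
add_glee_config(cfg)
backbone = D2SwinTransformer(cfg, ShapeSpec(channels=3))
feats = backbone(torch.randn(1, 3, 512, 512))
print({k: tuple(v.shape) for k, v in feats.items()})
# expected: {'res2': (1, 96, 128, 128), 'res3': (1, 192, 64, 64), 'res4': (1, 384, 32, 32), 'res5': (1, 768, 16, 16)}
print(backbone.output_shape()["res5"].stride)      # 32
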
GLEE/glee/backbone/vit.py
ADDED
@@ -0,0 +1,472 @@
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
import fvcore.nn.weight_init as weight_init
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from detectron2.layers import CNNBlockBase, Conv2d, get_norm
|
8 |
+
from detectron2.modeling.backbone.fpn import _assert_strides_are_log2_contiguous
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
|
12 |
+
from .utils import (
|
13 |
+
PatchEmbed,
|
14 |
+
add_decomposed_rel_pos,
|
15 |
+
get_abs_pos,
|
16 |
+
window_partition,
|
17 |
+
window_unpartition,
|
18 |
+
)
|
19 |
+
from functools import partial
|
20 |
+
import torch.utils.checkpoint as checkpoint
|
21 |
+
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
__all__ = ["ViT", "SimpleFeaturePyramid", "get_vit_lr_decay_rate"]
|
26 |
+
|
27 |
+
|
28 |
+
class Attention(nn.Module):
|
29 |
+
"""Multi-head Attention block with relative position embeddings."""
|
30 |
+
|
31 |
+
def __init__(
|
32 |
+
self,
|
33 |
+
dim,
|
34 |
+
num_heads=8,
|
35 |
+
qkv_bias=True,
|
36 |
+
use_rel_pos=False,
|
37 |
+
rel_pos_zero_init=True,
|
38 |
+
input_size=None,
|
39 |
+
):
|
40 |
+
"""
|
41 |
+
Args:
|
42 |
+
dim (int): Number of input channels.
|
43 |
+
num_heads (int): Number of attention heads.
|
44 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
45 |
+
rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
46 |
+
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
47 |
+
input_size (int or None): Input resolution for calculating the relative positional
|
48 |
+
parameter size.
|
49 |
+
"""
|
50 |
+
super().__init__()
|
51 |
+
self.num_heads = num_heads
|
52 |
+
head_dim = dim // num_heads
|
53 |
+
self.scale = head_dim**-0.5
|
54 |
+
|
55 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
56 |
+
self.proj = nn.Linear(dim, dim)
|
57 |
+
|
58 |
+
self.use_rel_pos = use_rel_pos
|
59 |
+
if self.use_rel_pos:
|
60 |
+
# initialize relative positional embeddings
|
61 |
+
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
|
62 |
+
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
|
63 |
+
|
64 |
+
if not rel_pos_zero_init:
|
65 |
+
nn.init.trunc_normal_(self.rel_pos_h, std=0.02)
|
66 |
+
nn.init.trunc_normal_(self.rel_pos_w, std=0.02)
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
B, H, W, _ = x.shape
|
70 |
+
# qkv with shape (3, B, nHead, H * W, C)
|
71 |
+
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
72 |
+
# q, k, v with shape (B * nHead, H * W, C)
|
73 |
+
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
|
74 |
+
|
75 |
+
with torch.backends.cuda.sdp_kernel(
|
76 |
+
enable_flash=True, enable_math=False, enable_mem_efficient=True
|
77 |
+
):
|
78 |
+
x = F.scaled_dot_product_attention(q, k, v)
|
79 |
+
attn = (q * self.scale) @ k.transpose(-2, -1)
|
80 |
+
|
81 |
+
if self.use_rel_pos:
|
82 |
+
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
|
83 |
+
|
84 |
+
attn = attn.softmax(dim=-1)
|
85 |
+
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
|
86 |
+
x = self.proj(x)
|
87 |
+
|
88 |
+
return x
|
89 |
+
|
90 |
+
|
91 |
+
class ResBottleneckBlock(CNNBlockBase):
|
92 |
+
"""
|
93 |
+
The standard bottleneck residual block without the last activation layer.
|
94 |
+
It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
|
95 |
+
"""
|
96 |
+
|
97 |
+
def __init__(
|
98 |
+
self,
|
99 |
+
in_channels,
|
100 |
+
out_channels,
|
101 |
+
bottleneck_channels,
|
102 |
+
norm="LN",
|
103 |
+
act_layer=nn.GELU,
|
104 |
+
):
|
105 |
+
"""
|
106 |
+
Args:
|
107 |
+
in_channels (int): Number of input channels.
|
108 |
+
out_channels (int): Number of output channels.
|
109 |
+
bottleneck_channels (int): number of output channels for the 3x3
|
110 |
+
"bottleneck" conv layers.
|
111 |
+
norm (str or callable): normalization for all conv layers.
|
112 |
+
See :func:`layers.get_norm` for supported format.
|
113 |
+
act_layer (callable): activation for all conv layers.
|
114 |
+
"""
|
115 |
+
super().__init__(in_channels, out_channels, 1)
|
116 |
+
|
117 |
+
self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
|
118 |
+
self.norm1 = get_norm(norm, bottleneck_channels)
|
119 |
+
self.act1 = act_layer()
|
120 |
+
|
121 |
+
self.conv2 = Conv2d(
|
122 |
+
bottleneck_channels,
|
123 |
+
bottleneck_channels,
|
124 |
+
3,
|
125 |
+
padding=1,
|
126 |
+
bias=False,
|
127 |
+
)
|
128 |
+
self.norm2 = get_norm(norm, bottleneck_channels)
|
129 |
+
self.act2 = act_layer()
|
130 |
+
|
131 |
+
self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
|
132 |
+
self.norm3 = get_norm(norm, out_channels)
|
133 |
+
|
134 |
+
for layer in [self.conv1, self.conv2, self.conv3]:
|
135 |
+
weight_init.c2_msra_fill(layer)
|
136 |
+
for layer in [self.norm1, self.norm2]:
|
137 |
+
layer.weight.data.fill_(1.0)
|
138 |
+
layer.bias.data.zero_()
|
139 |
+
# zero init last norm layer.
|
140 |
+
self.norm3.weight.data.zero_()
|
141 |
+
self.norm3.bias.data.zero_()
|
142 |
+
|
143 |
+
def forward(self, x):
|
144 |
+
out = x
|
145 |
+
for layer in self.children():
|
146 |
+
out = layer(out)
|
147 |
+
|
148 |
+
out = x + out
|
149 |
+
return out
|
150 |
+
|
151 |
+
|
152 |
+
class Block(nn.Module):
|
153 |
+
"""Transformer blocks with support of window attention and residual propagation blocks"""
|
154 |
+
|
155 |
+
def __init__(
|
156 |
+
self,
|
157 |
+
dim,
|
158 |
+
num_heads,
|
159 |
+
mlp_ratio=4.0,
|
160 |
+
qkv_bias=True,
|
161 |
+
drop_path=0.0,
|
162 |
+
norm_layer=nn.LayerNorm,
|
163 |
+
act_layer=nn.GELU,
|
164 |
+
use_rel_pos=False,
|
165 |
+
rel_pos_zero_init=True,
|
166 |
+
window_size=0,
|
167 |
+
use_residual_block=False,
|
168 |
+
input_size=None,
|
169 |
+
):
|
170 |
+
"""
|
171 |
+
Args:
|
172 |
+
dim (int): Number of input channels.
|
173 |
+
num_heads (int): Number of attention heads in each ViT block.
|
174 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
175 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
176 |
+
drop_path (float): Stochastic depth rate.
|
177 |
+
norm_layer (nn.Module): Normalization layer.
|
178 |
+
act_layer (nn.Module): Activation layer.
|
179 |
+
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
180 |
+
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
181 |
+
window_size (int): Window size for window attention blocks. If 0, do not
|
182 |
+
use window attention.
|
183 |
+
use_residual_block (bool): If True, use a residual block after the MLP block.
|
184 |
+
input_size (int or None): Input resolution for calculating the relative positional
|
185 |
+
parameter size.
|
186 |
+
"""
|
187 |
+
super().__init__()
|
188 |
+
self.norm1 = norm_layer(dim)
|
189 |
+
self.attn = Attention(
|
190 |
+
dim,
|
191 |
+
num_heads=num_heads,
|
192 |
+
qkv_bias=qkv_bias,
|
193 |
+
use_rel_pos=use_rel_pos,
|
194 |
+
rel_pos_zero_init=rel_pos_zero_init,
|
195 |
+
input_size=input_size if window_size == 0 else (window_size, window_size),
|
196 |
+
)
|
197 |
+
|
198 |
+
from timm.models.layers import DropPath, Mlp
|
199 |
+
|
200 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
201 |
+
self.norm2 = norm_layer(dim)
|
202 |
+
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer)
|
203 |
+
|
204 |
+
self.window_size = window_size
|
205 |
+
|
206 |
+
self.use_residual_block = use_residual_block
|
207 |
+
if use_residual_block:
|
208 |
+
# Use a residual block with bottleneck channel as dim // 2
|
209 |
+
self.residual = ResBottleneckBlock(
|
210 |
+
in_channels=dim,
|
211 |
+
out_channels=dim,
|
212 |
+
bottleneck_channels=dim // 2,
|
213 |
+
norm="LN",
|
214 |
+
act_layer=act_layer,
|
215 |
+
)
|
216 |
+
|
217 |
+
def forward(self, x):
|
218 |
+
shortcut = x
|
219 |
+
x = self.norm1(x)
|
220 |
+
# Window partition
|
221 |
+
if self.window_size > 0:
|
222 |
+
H, W = x.shape[1], x.shape[2]
|
223 |
+
x, pad_hw = window_partition(x, self.window_size)
|
224 |
+
x = self.attn(x)
|
225 |
+
# Reverse window partition
|
226 |
+
if self.window_size > 0:
|
227 |
+
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
|
228 |
+
|
229 |
+
x = shortcut + self.drop_path(x)
|
230 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
231 |
+
|
232 |
+
if self.use_residual_block:
|
233 |
+
x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
|
234 |
+
|
235 |
+
return x
|
236 |
+
|
237 |
+
|
238 |
+
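The window_size handling in Block above is what lets the ViT below interleave windowed and global attention: blocks listed in window_block_indexes get window_size, the rest get 0 and attend globally. A tiny illustration for the ViT-Base layout used by D2ViT further down (blocks 2, 5, 8, 11 global):

depth, window_size = 12, 14
window_block_indexes = [0, 1, 3, 4, 6, 7, 9, 10]
per_block = [window_size if i in window_block_indexes else 0 for i in range(depth)]
print(per_block)  # [14, 14, 0, 14, 14, 0, 14, 14, 0, 14, 14, 0]
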
class ViT(Backbone):
|
239 |
+
"""
|
240 |
+
This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`.
|
241 |
+
"Exploring Plain Vision Transformer Backbones for Object Detection",
|
242 |
+
https://arxiv.org/abs/2203.16527
|
243 |
+
"""
|
244 |
+
|
245 |
+
def __init__(
|
246 |
+
self,
|
247 |
+
img_size=1024,
|
248 |
+
patch_size=16,
|
249 |
+
in_chans=3,
|
250 |
+
embed_dim=768,
|
251 |
+
depth=12,
|
252 |
+
num_heads=12,
|
253 |
+
mlp_ratio=4.0,
|
254 |
+
qkv_bias=True,
|
255 |
+
drop_path_rate=0.0,
|
256 |
+
norm_layer=nn.LayerNorm,
|
257 |
+
act_layer=nn.GELU,
|
258 |
+
use_abs_pos=True,
|
259 |
+
use_rel_pos=False,
|
260 |
+
rel_pos_zero_init=True,
|
261 |
+
window_size=0,
|
262 |
+
window_block_indexes=(),
|
263 |
+
residual_block_indexes=(),
|
264 |
+
use_act_checkpoint=False,
|
265 |
+
pretrain_img_size=224,
|
266 |
+
pretrain_use_cls_token=True,
|
267 |
+
out_feature="last_feat",
|
268 |
+
):
|
269 |
+
"""
|
270 |
+
Args:
|
271 |
+
img_size (int): Input image size.
|
272 |
+
patch_size (int): Patch size.
|
273 |
+
in_chans (int): Number of input image channels.
|
274 |
+
embed_dim (int): Patch embedding dimension.
|
275 |
+
depth (int): Depth of ViT.
|
276 |
+
num_heads (int): Number of attention heads in each ViT block.
|
277 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
278 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
279 |
+
drop_path_rate (float): Stochastic depth rate.
|
280 |
+
norm_layer (nn.Module): Normalization layer.
|
281 |
+
act_layer (nn.Module): Activation layer.
|
282 |
+
use_abs_pos (bool): If True, use absolute positional embeddings.
|
283 |
+
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
284 |
+
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
285 |
+
window_size (int): Window size for window attention blocks.
|
286 |
+
window_block_indexes (list): Indexes for blocks using window attention.
|
287 |
+
residual_block_indexes (list): Indexes for blocks using conv propagation.
|
288 |
+
use_act_checkpoint (bool): If True, use activation checkpointing.
|
289 |
+
pretrain_img_size (int): input image size for pretraining models.
|
290 |
+
pretrain_use_cls_token (bool): If True, pretraining models use class token.
|
291 |
+
out_feature (str): name of the feature from the last block.
|
292 |
+
"""
|
293 |
+
super().__init__()
|
294 |
+
self.pretrain_use_cls_token = pretrain_use_cls_token
|
295 |
+
|
296 |
+
self.patch_embed = PatchEmbed(
|
297 |
+
kernel_size=(patch_size, patch_size),
|
298 |
+
stride=(patch_size, patch_size),
|
299 |
+
in_chans=in_chans,
|
300 |
+
embed_dim=embed_dim,
|
301 |
+
)
|
302 |
+
|
303 |
+
if use_abs_pos:
|
304 |
+
# Initialize absolute positional embedding with pretrain image size.
|
305 |
+
num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
|
306 |
+
num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
|
307 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
|
308 |
+
else:
|
309 |
+
self.pos_embed = None
|
310 |
+
|
311 |
+
# stochastic depth decay rule
|
312 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
|
313 |
+
|
314 |
+
self.blocks = nn.ModuleList()
|
315 |
+
for i in range(depth):
|
316 |
+
block = Block(
|
317 |
+
dim=embed_dim,
|
318 |
+
num_heads=num_heads,
|
319 |
+
mlp_ratio=mlp_ratio,
|
320 |
+
qkv_bias=qkv_bias,
|
321 |
+
drop_path=dpr[i],
|
322 |
+
norm_layer=norm_layer,
|
323 |
+
act_layer=act_layer,
|
324 |
+
use_rel_pos=use_rel_pos,
|
325 |
+
rel_pos_zero_init=rel_pos_zero_init,
|
326 |
+
window_size=window_size if i in window_block_indexes else 0,
|
327 |
+
use_residual_block=i in residual_block_indexes,
|
328 |
+
input_size=(img_size // patch_size, img_size // patch_size),
|
329 |
+
)
|
330 |
+
if use_act_checkpoint:
|
331 |
+
# TODO: use torch.utils.checkpoint
|
332 |
+
from fairscale.nn.checkpoint import checkpoint_wrapper
|
333 |
+
|
334 |
+
block = checkpoint_wrapper(block)
|
335 |
+
self.blocks.append(block)
|
336 |
+
|
337 |
+
self._out_feature_channels = {out_feature: embed_dim}
|
338 |
+
self._out_feature_strides = {out_feature: patch_size}
|
339 |
+
self._out_features = [out_feature]
|
340 |
+
|
341 |
+
if self.pos_embed is not None:
|
342 |
+
nn.init.trunc_normal_(self.pos_embed, std=0.02)
|
343 |
+
|
344 |
+
# In our method, we don't use backbone feature with stride 4
|
345 |
+
self.fpn1 = nn.Sequential(
|
346 |
+
nn.ConvTranspose2d(embed_dim, embed_dim // 2, kernel_size=2, stride=2),
|
347 |
+
)
|
348 |
+
self.fpn2 = nn.Identity()
|
349 |
+
self.fpn3 = nn.MaxPool2d(kernel_size=2, stride=2)
|
350 |
+
|
351 |
+
self.apply(self._init_weights)
|
352 |
+
|
353 |
+
def _init_weights(self, m):
|
354 |
+
if isinstance(m, nn.Linear):
|
355 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
356 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
357 |
+
nn.init.constant_(m.bias, 0)
|
358 |
+
elif isinstance(m, nn.LayerNorm):
|
359 |
+
nn.init.constant_(m.bias, 0)
|
360 |
+
nn.init.constant_(m.weight, 1.0)
|
361 |
+
|
362 |
+
def forward(self, x):
|
363 |
+
x = self.patch_embed(x)
|
364 |
+
if self.pos_embed is not None:
|
365 |
+
x = x + get_abs_pos(
|
366 |
+
self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
|
367 |
+
)
|
368 |
+
|
369 |
+
for blk in self.blocks:
|
370 |
+
x = blk(x)
|
371 |
+
xp = x.permute(0, 3, 1, 2) # (b, h, w, c) --> (b, c, h, w)
|
372 |
+
|
373 |
+
features = []
|
374 |
+
ops = [self.fpn1, self.fpn2, self.fpn3]
|
375 |
+
for i in range(len(ops)):
|
376 |
+
features.append(ops[i](xp))
|
377 |
+
rets = {"res{}".format(u + 3): v for (u,v) in enumerate(features)}
|
378 |
+
|
379 |
+
return rets
|
380 |
+
|
381 |
+
|
382 |
+
|
383 |
+
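The fpn1/fpn2/fpn3 heads built in ViT.__init__ turn the single stride-16 ViT feature map into the res3/res4/res5 pyramid returned by forward. A standalone shape sketch with a toy tensor and the same module definitions as above:

import torch
import torch.nn as nn

embed_dim = 768
xp = torch.randn(1, embed_dim, 64, 64)   # stride-16 map for a 1024x1024 input
fpn1 = nn.Sequential(nn.ConvTranspose2d(embed_dim, embed_dim // 2, kernel_size=2, stride=2))
fpn2 = nn.Identity()
fpn3 = nn.MaxPool2d(kernel_size=2, stride=2)
rets = {"res{}".format(u + 3): op(xp) for u, op in enumerate([fpn1, fpn2, fpn3])}
print({k: tuple(v.shape) for k, v in rets.items()})
# {'res3': (1, 384, 128, 128), 'res4': (1, 768, 64, 64), 'res5': (1, 768, 32, 32)}
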
@BACKBONE_REGISTRY.register()
|
384 |
+
class D2ViT(ViT, Backbone):
|
385 |
+
def __init__(self, cfg, input_shape):
|
386 |
+
use_checkpoint = cfg.MODEL.VIT.USE_CHECKPOINT
|
387 |
+
if cfg.MODEL.VIT.NAME == "ViT-Base":
|
388 |
+
embed_dim=768
|
389 |
+
depth=12
|
390 |
+
drop_path_rate=0.1
|
391 |
+
num_heads=12
|
392 |
+
elif cfg.MODEL.VIT.NAME == "ViT-Large":
|
393 |
+
embed_dim=1024
|
394 |
+
depth=24
|
395 |
+
drop_path_rate=0.4
|
396 |
+
num_heads=16
|
397 |
+
elif cfg.MODEL.VIT.NAME == "ViT-huge":
|
398 |
+
embed_dim=1280
|
399 |
+
depth=32
|
400 |
+
drop_path_rate=0.5
|
401 |
+
num_heads=16
|
402 |
+
else:
|
403 |
+
raise ValueError("Unsupported ViT name")
|
404 |
+
super().__init__(
|
405 |
+
img_size=1024,
|
406 |
+
patch_size=16,
|
407 |
+
in_chans=input_shape.channels,
|
408 |
+
embed_dim=embed_dim,
|
409 |
+
depth=depth,
|
410 |
+
num_heads=num_heads,
|
411 |
+
drop_path_rate=drop_path_rate,
|
412 |
+
window_size=14,
|
413 |
+
mlp_ratio=4,
|
414 |
+
qkv_bias=True,
|
415 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
416 |
+
window_block_indexes=[
|
417 |
+
# 2, 5, 8, 11 for global attention
|
418 |
+
0,
|
419 |
+
1,
|
420 |
+
3,
|
421 |
+
4,
|
422 |
+
6,
|
423 |
+
7,
|
424 |
+
9,
|
425 |
+
10,
|
426 |
+
],
|
427 |
+
residual_block_indexes=[],
|
428 |
+
use_rel_pos=True,
|
429 |
+
out_feature="last_feat",
|
430 |
+
use_act_checkpoint=use_checkpoint)
|
431 |
+
|
432 |
+
self._out_features = cfg.MODEL.VIT.OUT_FEATURES
|
433 |
+
|
434 |
+
self._out_feature_strides = {
|
435 |
+
"res3": 8,
|
436 |
+
"res4": 16,
|
437 |
+
"res5": 32,
|
438 |
+
}
|
439 |
+
self._out_feature_channels = {
|
440 |
+
"res3": embed_dim // 2,
|
441 |
+
"res4": embed_dim,
|
442 |
+
"res5": embed_dim,
|
443 |
+
}
|
444 |
+
|
445 |
+
def forward(self, x):
|
446 |
+
"""
|
447 |
+
Args:
|
448 |
+
x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
|
449 |
+
Returns:
|
450 |
+
dict[str->Tensor]: names and the corresponding features
|
451 |
+
"""
|
452 |
+
assert (
|
453 |
+
x.dim() == 4
|
454 |
+
), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
|
455 |
+
outputs = {}
|
456 |
+
y = super().forward(x)
|
457 |
+
for k in y.keys():
|
458 |
+
if k in self._out_features:
|
459 |
+
outputs[k] = y[k]
|
460 |
+
return outputs
|
461 |
+
|
462 |
+
def output_shape(self):
|
463 |
+
return {
|
464 |
+
name: ShapeSpec(
|
465 |
+
channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
|
466 |
+
)
|
467 |
+
for name in self._out_features
|
468 |
+
}
|
469 |
+
|
470 |
+
@property
|
471 |
+
def size_divisibility(self):
|
472 |
+
return 32
|
GLEE/glee/backbone/vit_utils.py
ADDED
@@ -0,0 +1,222 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
from scipy import interpolate
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
__all__ = [
|
10 |
+
"window_partition",
|
11 |
+
"window_unpartition",
|
12 |
+
"add_decomposed_rel_pos",
|
13 |
+
"get_abs_pos",
|
14 |
+
"PatchEmbed",
|
15 |
+
]
|
16 |
+
|
17 |
+
|
18 |
+
def window_partition(x, window_size):
|
19 |
+
"""
|
20 |
+
Partition into non-overlapping windows with padding if needed.
|
21 |
+
Args:
|
22 |
+
x (tensor): input tokens with [B, H, W, C].
|
23 |
+
window_size (int): window size.
|
24 |
+
|
25 |
+
Returns:
|
26 |
+
windows: windows after partition with [B * num_windows, window_size, window_size, C].
|
27 |
+
(Hp, Wp): padded height and width before partition
|
28 |
+
"""
|
29 |
+
B, H, W, C = x.shape
|
30 |
+
|
31 |
+
pad_h = (window_size - H % window_size) % window_size
|
32 |
+
pad_w = (window_size - W % window_size) % window_size
|
33 |
+
if pad_h > 0 or pad_w > 0:
|
34 |
+
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
|
35 |
+
Hp, Wp = H + pad_h, W + pad_w
|
36 |
+
|
37 |
+
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
|
38 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
39 |
+
return windows, (Hp, Wp)
|
40 |
+
|
41 |
+
|
42 |
+
def window_unpartition(windows, window_size, pad_hw, hw):
|
43 |
+
"""
|
44 |
+
Window unpartition into original sequences and removing padding.
|
45 |
+
Args:
|
46 |
+
x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
|
47 |
+
window_size (int): window size.
|
48 |
+
pad_hw (Tuple): padded height and width (Hp, Wp).
|
49 |
+
hw (Tuple): original height and width (H, W) before padding.
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
x: unpartitioned sequences with [B, H, W, C].
|
53 |
+
"""
|
54 |
+
Hp, Wp = pad_hw
|
55 |
+
H, W = hw
|
56 |
+
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
|
57 |
+
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
|
58 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
|
59 |
+
|
60 |
+
if Hp > H or Wp > W:
|
61 |
+
x = x[:, :H, :W, :].contiguous()
|
62 |
+
return x
|
63 |
+
|
64 |
+
|
65 |
+
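A round-trip check for the two helpers above (the import path is an assumption and depends on how the GLEE/ directory is put on PYTHONPATH): partitioning pads H and W up to multiples of the window size, and unpartitioning strips the padding again.

import torch
from glee.backbone.vit_utils import window_partition, window_unpartition  # assumed import path

x = torch.randn(2, 50, 70, 96)                 # (B, H, W, C), H is not a multiple of 14
windows, (Hp, Wp) = window_partition(x, 14)    # padded to 56 x 70 -> 4 * 5 windows per image
print(windows.shape)                           # torch.Size([40, 14, 14, 96])
y = window_unpartition(windows, 14, (Hp, Wp), (50, 70))
assert torch.equal(x, y)                       # padding is removed, content unchanged
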
def get_rel_pos(q_size, k_size, rel_pos, interp_type):
|
66 |
+
"""
|
67 |
+
Get relative positional embeddings according to the relative positions of
|
68 |
+
query and key sizes.
|
69 |
+
Args:
|
70 |
+
q_size (int): size of query q.
|
71 |
+
k_size (int): size of key k.
|
72 |
+
rel_pos (Tensor): relative position embeddings (L, C).
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
Extracted positional embeddings according to relative positions.
|
76 |
+
"""
|
77 |
+
max_rel_dist = int(2 * max(q_size, k_size) - 1)
|
78 |
+
# Interpolate rel pos if needed.
|
79 |
+
if rel_pos.shape[0] != max_rel_dist:
|
80 |
+
if interp_type == "vitdet":
|
81 |
+
# the vitdet impl:
|
82 |
+
# https://github.com/facebookresearch/detectron2/blob/96c752ce821a3340e27edd51c28a00665dd32a30/detectron2/modeling/backbone/utils.py#L77.
|
83 |
+
|
84 |
+
rel_pos_resized = F.interpolate(
|
85 |
+
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
|
86 |
+
size=max_rel_dist,
|
87 |
+
mode="linear",
|
88 |
+
)
|
89 |
+
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
|
90 |
+
elif interp_type == "beit":
|
91 |
+
# steal from beit https://github.com/microsoft/unilm/tree/master/beit
|
92 |
+
# modified by Yuxin Fang
|
93 |
+
|
94 |
+
src_size = rel_pos.shape[0]
|
95 |
+
dst_size = max_rel_dist
|
96 |
+
|
97 |
+
q = 1.0903078
|
98 |
+
dis = []
|
99 |
+
|
100 |
+
cur = 1
|
101 |
+
for i in range(src_size // 2):
|
102 |
+
dis.append(cur)
|
103 |
+
cur += q ** (i + 1)
|
104 |
+
|
105 |
+
r_ids = [-_ for _ in reversed(dis)]
|
106 |
+
x = r_ids + [0] + dis
|
107 |
+
t = dst_size // 2.0
|
108 |
+
dx = np.arange(-t, t + 0.1, 1.0)
|
109 |
+
|
110 |
+
all_rel_pos_bias = []
|
111 |
+
for i in range(rel_pos.shape[1]):
|
112 |
+
# a hack from https://github.com/baaivision/EVA/issues/8,
|
113 |
+
# could also be used in fine-tuning, but the performance hasn't been tested.
|
114 |
+
z = rel_pos[:, i].view(src_size).cpu().float().detach().numpy()
|
115 |
+
f = interpolate.interp1d(x, z, kind='cubic', fill_value="extrapolate")
|
116 |
+
all_rel_pos_bias.append(
|
117 |
+
torch.Tensor(f(dx)).contiguous().view(-1, 1).to(rel_pos.device))
|
118 |
+
rel_pos_resized = torch.cat(all_rel_pos_bias, dim=-1)
|
119 |
+
else:
|
120 |
+
raise NotImplementedError()
|
121 |
+
else:
|
122 |
+
rel_pos_resized = rel_pos
|
123 |
+
|
124 |
+
# Scale the coords with short length if shapes for q and k are different.
|
125 |
+
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
|
126 |
+
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
|
127 |
+
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
|
128 |
+
|
129 |
+
return rel_pos_resized[relative_coords.long()]
|
130 |
+
|
131 |
+
|
132 |
+
def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size, interp_type):
|
133 |
+
"""
|
134 |
+
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
|
135 |
+
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
|
136 |
+
Args:
|
137 |
+
attn (Tensor): attention map.
|
138 |
+
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
|
139 |
+
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
|
140 |
+
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
|
141 |
+
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
|
142 |
+
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
attn (Tensor): attention map with added relative positional embeddings.
|
146 |
+
"""
|
147 |
+
q_h, q_w = q_size
|
148 |
+
k_h, k_w = k_size
|
149 |
+
Rh = get_rel_pos(q_h, k_h, rel_pos_h, interp_type)
|
150 |
+
Rw = get_rel_pos(q_w, k_w, rel_pos_w, interp_type)
|
151 |
+
|
152 |
+
B, _, dim = q.shape
|
153 |
+
r_q = q.reshape(B, q_h, q_w, dim)
|
154 |
+
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
|
155 |
+
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
|
156 |
+
|
157 |
+
attn = (
|
158 |
+
attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
|
159 |
+
).view(B, q_h * q_w, k_h * k_w)
|
160 |
+
|
161 |
+
return attn
|
162 |
+
|
163 |
+
|
164 |
+
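The einsum-based bias above decomposes the relative-position term into a height part and a width part that are broadcast onto the (q_h*q_w, k_h*k_w) attention map. A standalone shape check with toy tensors (this reproduces only the broadcasting, not the repo's lookup of Rh/Rw via get_rel_pos):

import torch

B, q_h, q_w, C = 2, 7, 7, 32
r_q = torch.randn(B, q_h, q_w, C)
Rh = torch.randn(q_h, q_h, C)   # toy stand-in for get_rel_pos along the height axis
Rw = torch.randn(q_w, q_w, C)   # toy stand-in for get_rel_pos along the width axis

rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
attn = torch.zeros(B, q_h * q_w, q_h * q_w)
attn = (
    attn.view(B, q_h, q_w, q_h, q_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
).view(B, q_h * q_w, q_h * q_w)
print(attn.shape)               # torch.Size([2, 49, 49])
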
def get_abs_pos(abs_pos, has_cls_token, hw):
|
165 |
+
"""
|
166 |
+
Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
|
167 |
+
dimension for the original embeddings.
|
168 |
+
Args:
|
169 |
+
abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
|
170 |
+
has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
|
171 |
+
hw (Tuple): size of input image tokens.
|
172 |
+
|
173 |
+
Returns:
|
174 |
+
Absolute positional embeddings after processing with shape (1, H, W, C)
|
175 |
+
"""
|
176 |
+
h, w = hw
|
177 |
+
if has_cls_token:
|
178 |
+
abs_pos = abs_pos[:, 1:]
|
179 |
+
xy_num = abs_pos.shape[1]
|
180 |
+
size = int(math.sqrt(xy_num))
|
181 |
+
assert size * size == xy_num
|
182 |
+
|
183 |
+
if size != h or size != w:
|
184 |
+
new_abs_pos = F.interpolate(
|
185 |
+
abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
|
186 |
+
size=(h, w),
|
187 |
+
mode="bicubic",
|
188 |
+
align_corners=False,
|
189 |
+
)
|
190 |
+
|
191 |
+
return new_abs_pos.permute(0, 2, 3, 1)
|
192 |
+
else:
|
193 |
+
return abs_pos.reshape(1, h, w, -1)
|
194 |
+
|
195 |
+
|
196 |
+
class PatchEmbed(nn.Module):
|
197 |
+
"""
|
198 |
+
Image to Patch Embedding.
|
199 |
+
"""
|
200 |
+
|
201 |
+
def __init__(
|
202 |
+
self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768
|
203 |
+
):
|
204 |
+
"""
|
205 |
+
Args:
|
206 |
+
kernel_size (Tuple): kernel size of the projection layer.
|
207 |
+
stride (Tuple): stride of the projection layer.
|
208 |
+
padding (Tuple): padding size of the projection layer.
|
209 |
+
in_chans (int): Number of input image channels.
|
210 |
+
embed_dim (int): Patch embedding dimension.
|
211 |
+
"""
|
212 |
+
super().__init__()
|
213 |
+
|
214 |
+
self.proj = nn.Conv2d(
|
215 |
+
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
|
216 |
+
)
|
217 |
+
|
218 |
+
def forward(self, x):
|
219 |
+
x = self.proj(x)
|
220 |
+
# B C H W -> B H W C
|
221 |
+
x = x.permute(0, 2, 3, 1)
|
222 |
+
return x
|
GLEE/glee/config.py
ADDED
@@ -0,0 +1,387 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
from detectron2.config import CfgNode as CN
|
3 |
+
|
4 |
+
|
5 |
+
def add_glee_config(cfg):
|
6 |
+
"""
|
7 |
+
Add config for GLEE.
|
8 |
+
"""
|
9 |
+
|
10 |
+
cfg.FIND_UNUSED_PARAMETERS = True
|
11 |
+
cfg.MODEL.MAX_CATEGORY_LEN = 100
|
12 |
+
cfg.MODEL.PSEUDO_VIDEO = False
|
13 |
+
cfg.MODEL.FREEZE_WHOLE = False
|
14 |
+
cfg.MODEL.CONTRAS_MEAN = False
|
15 |
+
cfg.MODEL.CROSS_TRACK = False
|
16 |
+
cfg.MODEL.TRACK_VERSION = 'v3'
|
17 |
+
|
18 |
+
cfg.INPUT.SAMPLING_FRAME_NUM = 1
|
19 |
+
cfg.INPUT.SAMPLING_FRAME_RANGE = 10
|
20 |
+
cfg.INPUT.SAMPLING_INTERVAL = 1
|
21 |
+
cfg.INPUT.SAMPLING_FRAME_SHUFFLE = False
|
22 |
+
cfg.INPUT.AUGMENTATIONS = [] # "brightness", "contrast", "saturation", "rotation"
|
23 |
+
cfg.INPUT.DATASET_MAPPER_NAME = None
|
24 |
+
|
25 |
+
cfg.DATALOADER.DATASET_RATIO = [1, 1]
|
26 |
+
cfg.DATALOADER.USE_DIFF_BS_SIZE = True
|
27 |
+
cfg.DATALOADER.DATASET_BS = [2, 2]
|
28 |
+
cfg.DATALOADER.DATASET_FILTERS = [True, True]
|
29 |
+
cfg.DATALOADER.USE_RFS = [False, False]
|
30 |
+
cfg.DATALOADER.MULTI_DATASET_GROUPING = True
|
31 |
+
cfg.DATALOADER.DATASET_ANN = ['image']
|
32 |
+
|
33 |
+
|
34 |
+
cfg.INPUT.SIZE_DIVISIBILITY = -1
|
35 |
+
|
36 |
+
cfg.DATALOADER.DATASET_RATIO = [1, 1]
|
37 |
+
cfg.DATALOADER.USE_DIFF_BS_SIZE = True
|
38 |
+
cfg.DATALOADER.DATASET_BS = [2, 2]
|
39 |
+
cfg.DATALOADER.USE_RFS = [False, False]
|
40 |
+
cfg.DATALOADER.MULTI_DATASET_GROUPING = True
|
41 |
+
cfg.DATALOADER.DATASET_ANN = ['box', 'box']
|
42 |
+
|
43 |
+
# Allow different datasets to use different input resolutions
|
44 |
+
cfg.INPUT.MIN_SIZE_TRAIN_MULTI = [(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), (320, 352, 392, 416, 448, 480, 512, 544, 576, 608, 640)]
|
45 |
+
cfg.INPUT.MAX_SIZE_TRAIN_MULTI = [1333, 768]
|
46 |
+
|
47 |
+
|
48 |
+
# MaskDINO model config
|
49 |
+
cfg.MODEL.MaskDINO = CN()
|
50 |
+
cfg.MODEL.MaskDINO.LEARN_TGT = False
|
51 |
+
|
52 |
+
# loss
|
53 |
+
cfg.MODEL.MaskDINO.PANO_BOX_LOSS = False
|
54 |
+
cfg.MODEL.MaskDINO.SEMANTIC_CE_LOSS = False
|
55 |
+
cfg.MODEL.MaskDINO.DEEP_SUPERVISION = True
|
56 |
+
cfg.MODEL.MaskDINO.NO_OBJECT_WEIGHT = 0.1
|
57 |
+
cfg.MODEL.MaskDINO.CLASS_WEIGHT = 4.0
|
58 |
+
cfg.MODEL.MaskDINO.DICE_WEIGHT = 5.0
|
59 |
+
cfg.MODEL.MaskDINO.MASK_WEIGHT = 5.0
|
60 |
+
cfg.MODEL.MaskDINO.BOX_WEIGHT = 5.
|
61 |
+
cfg.MODEL.MaskDINO.GIOU_WEIGHT = 2.
|
62 |
+
|
63 |
+
# cost weight
|
64 |
+
cfg.MODEL.MaskDINO.COST_CLASS_WEIGHT = 4.0
|
65 |
+
cfg.MODEL.MaskDINO.COST_DICE_WEIGHT = 5.0
|
66 |
+
cfg.MODEL.MaskDINO.COST_MASK_WEIGHT = 5.0
|
67 |
+
cfg.MODEL.MaskDINO.COST_BOX_WEIGHT = 5.
|
68 |
+
cfg.MODEL.MaskDINO.COST_GIOU_WEIGHT = 2.
|
69 |
+
|
70 |
+
# transformer config
|
71 |
+
cfg.MODEL.MaskDINO.NHEADS = 8
|
72 |
+
cfg.MODEL.MaskDINO.DROPOUT = 0.1
|
73 |
+
cfg.MODEL.MaskDINO.DIM_FEEDFORWARD = 2048
|
74 |
+
cfg.MODEL.MaskDINO.ENC_LAYERS = 0
|
75 |
+
cfg.MODEL.MaskDINO.DEC_LAYERS = 6
|
76 |
+
cfg.MODEL.MaskDINO.INITIAL_PRED = True
|
77 |
+
cfg.MODEL.MaskDINO.PRE_NORM = False
|
78 |
+
cfg.MODEL.MaskDINO.BOX_LOSS = True
|
79 |
+
cfg.MODEL.MaskDINO.HIDDEN_DIM = 256
|
80 |
+
cfg.MODEL.MaskDINO.NUM_OBJECT_QUERIES = 100
|
81 |
+
|
82 |
+
cfg.MODEL.MaskDINO.ENFORCE_INPUT_PROJ = False
|
83 |
+
cfg.MODEL.MaskDINO.TWO_STAGE = True
|
84 |
+
cfg.MODEL.MaskDINO.INITIALIZE_BOX_TYPE = 'no' # ['no', 'bitmask', 'mask2box']
|
85 |
+
cfg.MODEL.MaskDINO.DN="seg"
|
86 |
+
cfg.MODEL.MaskDINO.DN_NOISE_SCALE=0.4
|
87 |
+
cfg.MODEL.MaskDINO.DN_NUM=100
|
88 |
+
cfg.MODEL.MaskDINO.PRED_CONV=False
|
89 |
+
|
90 |
+
cfg.MODEL.MaskDINO.EVAL_FLAG = 1
|
91 |
+
|
92 |
+
# MSDeformAttn encoder configs
|
93 |
+
cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["res3", "res4", "res5"]
|
94 |
+
cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_N_POINTS = 4
|
95 |
+
cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_N_HEADS = 8
|
96 |
+
cfg.MODEL.SEM_SEG_HEAD.DIM_FEEDFORWARD = 2048
|
97 |
+
cfg.MODEL.SEM_SEG_HEAD.NUM_FEATURE_LEVELS = 3
|
98 |
+
cfg.MODEL.SEM_SEG_HEAD.TOTAL_NUM_FEATURE_LEVELS = 4
|
99 |
+
cfg.MODEL.SEM_SEG_HEAD.FEATURE_ORDER = 'high2low' # ['low2high', 'high2low'] high2low: from high level to low level
|
100 |
+
|
101 |
+
#####################
|
102 |
+
|
103 |
+
# MaskDINO inference config
|
104 |
+
cfg.MODEL.MaskDINO.TEST = CN()
|
105 |
+
cfg.MODEL.MaskDINO.TEST.TEST_FOUCUS_ON_BOX = False
|
106 |
+
cfg.MODEL.MaskDINO.TEST.SEMANTIC_ON = True
|
107 |
+
cfg.MODEL.MaskDINO.TEST.INSTANCE_ON = False
|
108 |
+
cfg.MODEL.MaskDINO.TEST.PANOPTIC_ON = False
|
109 |
+
cfg.MODEL.MaskDINO.TEST.OBJECT_MASK_THRESHOLD = 0.0
|
110 |
+
cfg.MODEL.MaskDINO.TEST.OVERLAP_THRESHOLD = 0.0
|
111 |
+
cfg.MODEL.MaskDINO.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False
|
112 |
+
cfg.MODEL.MaskDINO.TEST.PANO_TRANSFORM_EVAL = True
|
113 |
+
cfg.MODEL.MaskDINO.TEST.PANO_TEMPERATURE = 0.06
|
114 |
+
# cfg.MODEL.MaskDINO.TEST.EVAL_FLAG = 1
|
115 |
+
|
116 |
+
# Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. ResNet)
|
117 |
+
# you can use this config to override
|
118 |
+
cfg.MODEL.MaskDINO.SIZE_DIVISIBILITY = 32
|
119 |
+
|
120 |
+
# pixel decoder config
|
121 |
+
cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
|
122 |
+
# adding transformer in pixel decoder
|
123 |
+
cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 0
|
124 |
+
# pixel decoder
|
125 |
+
cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME = "MaskDINOEncoder"
|
126 |
+
|
127 |
+
# transformer module
|
128 |
+
cfg.MODEL.MaskDINO.TRANSFORMER_DECODER_NAME = "MaskDINODecoder"
|
129 |
+
|
130 |
+
# LSJ aug
|
131 |
+
cfg.INPUT.IMAGE_SIZE = 1024
|
132 |
+
cfg.INPUT.MIN_SCALE = 0.1
|
133 |
+
cfg.INPUT.MAX_SCALE = 2.0
|
134 |
+
|
135 |
+
# point loss configs
|
136 |
+
# Number of points sampled during training for a mask point head.
|
137 |
+
cfg.MODEL.MaskDINO.TRAIN_NUM_POINTS = 112 * 112
|
138 |
+
# Oversampling parameter for PointRend point sampling during training. Parameter `k` in the
|
139 |
+
# original paper.
|
140 |
+
cfg.MODEL.MaskDINO.OVERSAMPLE_RATIO = 3.0
|
141 |
+
# Importance sampling parameter for PointRend point sampling during training. Parametr `beta` in
|
142 |
+
# the original paper.
|
143 |
+
cfg.MODEL.MaskDINO.IMPORTANCE_SAMPLE_RATIO = 0.75
|
144 |
+
|
145 |
+
|
146 |
+
|
147 |
+
|
148 |
+
cfg.MODEL.DIM_PROJ = 256
|
149 |
+
cfg.MODEL.VISUAL_PROMPT = False
|
150 |
+
cfg.MODEL.TEXT = CN()
|
151 |
+
cfg.MODEL.TEXT.ARCH = 'vlpencoder'
|
152 |
+
cfg.MODEL.TEXT.NAME= 'transformer'
|
153 |
+
cfg.MODEL.TEXT.TOKENIZER= 'clip'
|
154 |
+
cfg.MODEL.TEXT.CONTEXT_LENGTH= 77 # 77
|
155 |
+
cfg.MODEL.TEXT.WIDTH= 512
|
156 |
+
cfg.MODEL.TEXT.HEADS= 8
|
157 |
+
cfg.MODEL.TEXT.LAYERS= 12 # 6
|
158 |
+
cfg.MODEL.TEXT.AUTOGRESSIVE= True
|
159 |
+
|
160 |
+
|
161 |
+
|
162 |
+
cfg.MODEL.LANGUAGE_BACKBONE = CN()
|
163 |
+
cfg.MODEL.LANGUAGE_BACKBONE.USE_CHECKPOINT = False
|
164 |
+
cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE = "bert-base-uncased"
|
165 |
+
cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE = "bert-base-uncased"
|
166 |
+
cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM = 768
|
167 |
+
cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN = 77 # max length of the tokenized captions.
|
168 |
+
cfg.MODEL.LANGUAGE_BACKBONE.N_LAYERS = 1
|
169 |
+
# cfg.MODEL.LANGUAGE_BACKBONE.UNUSED_TOKEN = 106
|
170 |
+
# cfg.MODEL.LANGUAGE_BACKBONE.MASK_SPECIAL = False
|
171 |
+
cfg.MODEL.LANGUAGE_BACKBONE.PAD_MAX = True
|
172 |
+
|
173 |
+
|
174 |
+
|
175 |
+
|
176 |
+
|
177 |
+
cfg.MODEL.ENCODER = CN()
|
178 |
+
cfg.MODEL.ENCODER.NAME= 'transformer_encoder_fpn'
|
179 |
+
cfg.MODEL.ENCODER.IGNORE_VALUE= 255
|
180 |
+
cfg.MODEL.ENCODER.NUM_CLASSES= 133
|
181 |
+
cfg.MODEL.ENCODER.LOSS_WEIGHT= 1.0
|
182 |
+
cfg.MODEL.ENCODER.CONVS_DIM= 512
|
183 |
+
cfg.MODEL.ENCODER.MASK_DIM= 512
|
184 |
+
cfg.MODEL.ENCODER.NORM= "GN"
|
185 |
+
cfg.MODEL.ENCODER.IN_FEATURES= ["res2", "res3", "res4", "res5"]
|
186 |
+
cfg.MODEL.ENCODER.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES= ["res3", "res4", "res5"]
|
187 |
+
cfg.MODEL.ENCODER.COMMON_STRIDE= 4
|
188 |
+
cfg.MODEL.ENCODER.TRANSFORMER_ENC_LAYERS= 6
|
189 |
+
|
190 |
+
cfg.MODEL.DECODER = CN()
|
191 |
+
cfg.MODEL.DECODER.TRANSFORMER_IN_FEATURE= "multi_scale_pixel_decoder"
|
192 |
+
cfg.MODEL.DECODER.MASK = True
|
193 |
+
# DETECTION= False
|
194 |
+
# SPATIAL=
|
195 |
+
# ENABLED= True
|
196 |
+
# GROUNDING=
|
197 |
+
# ENABLED= False
|
198 |
+
# MAX_LEN= 5
|
199 |
+
# TEXT_WEIGHT= 2.0
|
200 |
+
# CLASS_WEIGHT= 0.5
|
201 |
+
# VISUAL=
|
202 |
+
# ENABLED= False
|
203 |
+
# AUDIO=
|
204 |
+
# ENABLED= False
|
205 |
+
# OPENIMAGE=
|
206 |
+
# ENABLED= False
|
207 |
+
# NEGATIVE_SAMPLES= 5
|
208 |
+
# GROUNDING=
|
209 |
+
# ENABLED= False
|
210 |
+
# MAX_LEN= 5
|
211 |
+
# CAPTION=
|
212 |
+
# ENABLED= False
|
213 |
+
# PHRASE_PROB= 0.5
|
214 |
+
# SIM_THRES= 0.95
|
215 |
+
cfg.MODEL.DECODER.HIDDEN_DIM= 512
|
216 |
+
cfg.MODEL.DECODER.NUM_OBJECT_QUERIES= 101
|
217 |
+
cfg.MODEL.DECODER.NHEADS= 8
|
218 |
+
cfg.MODEL.DECODER.DROPOUT= 0.0
|
219 |
+
cfg.MODEL.DECODER.DIM_FEEDFORWARD= 2048
|
220 |
+
cfg.MODEL.DECODER.MAX_SPATIAL_LEN= [512, 512, 512, 512]
|
221 |
+
cfg.MODEL.DECODER.PRE_NORM= False
|
222 |
+
cfg.MODEL.DECODER.ENFORCE_INPUT_PROJ= False
|
223 |
+
cfg.MODEL.DECODER.SIZE_DIVISIBILITY= 32
|
224 |
+
cfg.MODEL.DECODER.TRAIN_NUM_POINTS= 12544
|
225 |
+
cfg.MODEL.DECODER.OVERSAMPLE_RATIO= 3.0
|
226 |
+
cfg.MODEL.DECODER.IMPORTANCE_SAMPLE_RATIO= 0.75
|
227 |
+
cfg.MODEL.DECODER.DEC_LAYERS= 10 # 9 decoder layers, add one for the loss on learnable query
|
228 |
+
cfg.MODEL.DECODER.TOP_GROUNDING_LAYERS= 10
|
229 |
+
cfg.MODEL.DECODER.TOP_CAPTION_LAYERS= 10
|
230 |
+
cfg.MODEL.DECODER.TOP_SPATIAL_LAYERS= 10
|
231 |
+
cfg.MODEL.DECODER.TOP_OPENIMAGE_LAYERS= 10
|
232 |
+
# TEST=
|
233 |
+
# SEMANTIC_ON= True
|
234 |
+
# INSTANCE_ON= True
|
235 |
+
# PANOPTIC_ON= True
|
236 |
+
# OVERLAP_THRESHOLD= 0.8
|
237 |
+
# OBJECT_MASK_THRESHOLD= 0.4
|
238 |
+
# SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE= false
|
239 |
+
# DETECTIONS_PER_IMAGE= 100
|
240 |
+
|
241 |
+
cfg.ATTENTION_ARCH = CN()
|
242 |
+
# cfg.ATTENTION_ARCH.VARIABLE={
|
243 |
+
# 'queries': ['object'],
|
244 |
+
# 'tokens': ['grounding', 'spatial', 'visual', 'audio']}
|
245 |
+
|
246 |
+
# SELF_ATTENTION:
|
247 |
+
# queries:
|
248 |
+
# object: ['queries_object', 'tokens_grounding', 'tokens_spatial', 'tokens_visual', 'tokens_audio']
|
249 |
+
# tokens:
|
250 |
+
# grounding: ['queries_object', 'tokens_grounding']
|
251 |
+
# spatial: ['tokens_spatial']
|
252 |
+
# visual: ['tokens_visual']
|
253 |
+
# audio: ['queries_object', 'tokens_audio']
|
254 |
+
# CROSS_ATTENTION:
|
255 |
+
# queries:
|
256 |
+
# object: True
|
257 |
+
# tokens:
|
258 |
+
# grounding: False
|
259 |
+
# spatial: False
|
260 |
+
# visual: False
|
261 |
+
# audio: False
|
262 |
+
# MASKING: ['tokens_spatial', 'tokens_grounding', 'tokens_visual', 'tokens_audio']
|
263 |
+
# DUPLICATION:
|
264 |
+
# queries:
|
265 |
+
# grounding: 'queries_object'
|
266 |
+
# spatial: 'queries_object'
|
267 |
+
# SPATIAL_MEMORIES: 32
|
268 |
+
|
269 |
+
|
270 |
+
|
271 |
+
|
272 |
+
|
273 |
+
|
274 |
+
cfg.SOLVER.OPTIMIZER = "ADAMW"
|
275 |
+
cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1
|
276 |
+
cfg.SOLVER.TEXTENCODER_MULTIPLIER = 1.0
|
277 |
+
cfg.SOLVER.LR_DECAY_RATE = None
|
278 |
+
cfg.SOLVER.LR_DECAY_RATE_NUM_LAYERS = None
|
279 |
+
|
280 |
+
|
281 |
+
## support Swin backbone
|
282 |
+
cfg.MODEL.SWIN = CN()
|
283 |
+
cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE = 224
|
284 |
+
cfg.MODEL.SWIN.PATCH_SIZE = 4
|
285 |
+
cfg.MODEL.SWIN.EMBED_DIM = 96
|
286 |
+
cfg.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
|
287 |
+
cfg.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
|
288 |
+
cfg.MODEL.SWIN.WINDOW_SIZE = 7
|
289 |
+
cfg.MODEL.SWIN.MLP_RATIO = 4.0
|
290 |
+
cfg.MODEL.SWIN.QKV_BIAS = True
|
291 |
+
cfg.MODEL.SWIN.QK_SCALE = None
|
292 |
+
cfg.MODEL.SWIN.DROP_RATE = 0.0
|
293 |
+
cfg.MODEL.SWIN.ATTN_DROP_RATE = 0.0
|
294 |
+
cfg.MODEL.SWIN.DROP_PATH_RATE = 0.3
|
295 |
+
cfg.MODEL.SWIN.APE = False
|
296 |
+
cfg.MODEL.SWIN.PATCH_NORM = True
|
297 |
+
cfg.MODEL.SWIN.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
|
298 |
+
cfg.MODEL.SWIN.USE_CHECKPOINT = False
|
299 |
+
cfg.MODEL.SWIN.PRETRAINED_WEIGHT = None
|
300 |
+
|
301 |
+
|
302 |
+
# support InternImage backbone
|
303 |
+
cfg.MODEL.INTERNIMAGE = CN() # large as base
|
304 |
+
|
305 |
+
#### large
|
306 |
+
cfg.MODEL.INTERNIMAGE.PRETRAINED_WEIGHT = None
|
307 |
+
cfg.MODEL.INTERNIMAGE.CORE_OP = "DCNv3"
|
308 |
+
cfg.MODEL.INTERNIMAGE.CHANNELS = 160
|
309 |
+
cfg.MODEL.INTERNIMAGE.DEPTHS = [5, 5, 22, 5]
|
310 |
+
cfg.MODEL.INTERNIMAGE.GROUPS =[10, 20, 40, 80]
|
311 |
+
cfg.MODEL.INTERNIMAGE.MLP_RATIO =4.
|
312 |
+
cfg.MODEL.INTERNIMAGE.DROP_PATH_RATE =0.0
|
313 |
+
cfg.MODEL.INTERNIMAGE.NORM_LAYER = "LN"
|
314 |
+
cfg.MODEL.INTERNIMAGE.LAYER_SCALE = 1.0
|
315 |
+
cfg.MODEL.INTERNIMAGE.OFFSET_SCALE = 2.0
|
316 |
+
cfg.MODEL.INTERNIMAGE.POST_NORM = True
|
317 |
+
cfg.MODEL.INTERNIMAGE.WITH_CP = False
|
318 |
+
cfg.MODEL.INTERNIMAGE.OUT_IINDICES = (0, 1, 2, 3)
|
319 |
+
cfg.MODEL.INTERNIMAGE.DW_KERNEL_SIZE = None
|
320 |
+
cfg.MODEL.INTERNIMAGE.RES_POST_NORM = False
|
321 |
+
cfg.MODEL.INTERNIMAGE.LEVEL2_POST_NORM = False
|
322 |
+
cfg.MODEL.INTERNIMAGE.LEVEL2_POST_NORM_BLOCK_IDS = None
|
323 |
+
cfg.MODEL.INTERNIMAGE.CENTER_FEATURE_SCALE = False
|
324 |
+
|
325 |
+
### huge
|
326 |
+
# cfg.MODEL.INTERNIMAGE.PRETRAINED_WEIGHT = None
|
327 |
+
# cfg.MODEL.INTERNIMAGE.CORE_OP = "DCNv3"
|
328 |
+
# cfg.MODEL.INTERNIMAGE.CHANNELS = 320
|
329 |
+
# cfg.MODEL.INTERNIMAGE.DEPTHS = [6, 6, 32, 6]
|
330 |
+
# cfg.MODEL.INTERNIMAGE.GROUPS = [10, 20, 40, 80]
|
331 |
+
# cfg.MODEL.INTERNIMAGE.MLP_RATIO =4.
|
332 |
+
# cfg.MODEL.INTERNIMAGE.DROP_PATH_RATE = 0.5
|
333 |
+
# cfg.MODEL.INTERNIMAGE.NORM_LAYER = "LN"
|
334 |
+
# cfg.MODEL.INTERNIMAGE.LAYER_SCALE = None
|
335 |
+
# cfg.MODEL.INTERNIMAGE.OFFSET_SCALE = 1.0
|
336 |
+
# cfg.MODEL.INTERNIMAGE.POST_NORM = False
|
337 |
+
# cfg.MODEL.INTERNIMAGE.WITH_CP = False
|
338 |
+
# cfg.MODEL.INTERNIMAGE.OUT_IINDICES = (0, 1, 2, 3)
|
339 |
+
|
340 |
+
# cfg.MODEL.INTERNIMAGE.DW_KERNEL_SIZE = 5
|
341 |
+
# cfg.MODEL.INTERNIMAGE.RES_POST_NORM = True
|
342 |
+
# cfg.MODEL.INTERNIMAGE.LEVEL2_POST_NORM = True
|
343 |
+
# cfg.MODEL.INTERNIMAGE.LEVEL2_POST_NORM_BLOCK_IDS = [5, 11, 17, 23, 29]
|
344 |
+
# cfg.MODEL.INTERNIMAGE.CENTER_FEATURE_SCALE = True
|
345 |
+
|
346 |
+
|
347 |
+
# support EVA02 backbone
|
348 |
+
cfg.MODEL.EVA02 = CN() # large as base
|
349 |
+
|
350 |
+
#### large
|
351 |
+
cfg.MODEL.EVA02.PRETRAINED_WEIGHT = None
|
352 |
+
cfg.MODEL.EVA02.IMAGE_SIZE = 1536
|
353 |
+
cfg.MODEL.EVA02.PATCH_SIZE = 16
|
354 |
+
cfg.MODEL.EVA02.WINDOW_SIZE = 16
|
355 |
+
cfg.MODEL.EVA02.DMBED_DIM =1024
|
356 |
+
cfg.MODEL.EVA02.DEPTH = 24
|
357 |
+
cfg.MODEL.EVA02.NUM_HEADS = 16
|
358 |
+
cfg.MODEL.EVA02.MLP_RATIO = 4*2/3
|
359 |
+
cfg.MODEL.EVA02.DROP_PATH_RATE = 0.3
|
360 |
+
cfg.MODEL.EVA02.CHECKPOINT = True
|
361 |
+
cfg.MODEL.EVA02.WINDOW_BLOCK_INDEXES = [0, 1, 3, 4, 6, 7, 9, 10, 12, 13, 15, 16, 18, 19, 21, 22]
|
362 |
+
|
363 |
+
|
364 |
+
|
365 |
+
# support EVA01 backbone
|
366 |
+
cfg.MODEL.EVA01 = CN() # large as base
|
367 |
+
|
368 |
+
#### large
|
369 |
+
cfg.MODEL.EVA01.PRETRAINED_WEIGHT = None
|
370 |
+
|
371 |
+
cfg.MODEL.EVA01.BEIT_LIKE_QKV_BIAS = True
|
372 |
+
cfg.MODEL.EVA01.BEIT_LIKE_GAMMA = False
|
373 |
+
cfg.MODEL.EVA01.FREEZE_PATH_EMBED = True
|
374 |
+
|
375 |
+
cfg.MODEL.EVA01.IMAGE_SIZE = 1280 # only for correct dim in pos embed
|
376 |
+
cfg.MODEL.EVA01.PATCH_SIZE = 16
|
377 |
+
cfg.MODEL.EVA01.WINDOW_SIZE = 16
|
378 |
+
cfg.MODEL.EVA01.DMBED_DIM = 1408
|
379 |
+
cfg.MODEL.EVA01.DEPTH = 40
|
380 |
+
cfg.MODEL.EVA01.NUM_HEADS = 16
|
381 |
+
cfg.MODEL.EVA01.MLP_RATIO = 6144 / 1408
|
382 |
+
cfg.MODEL.EVA01.DROP_PATH_RATE = 0.6
|
383 |
+
cfg.MODEL.EVA01.WINDOW_BLOCK_INDEXES = [0, 1, 2, 4, 5, 6, 8, 9, 10, 12, 13, 14, 16, 17, 18, 20, 21, 22, 24, 25, 26, 28, 29, 30, 32, 33, 34, 36, 37, 38]
|
384 |
+
|
385 |
+
|
386 |
+
|
387 |
+
|
GLEE/glee/config_deeplab.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
3 |
+
|
4 |
+
|
5 |
+
def add_deeplab_config(cfg):
|
6 |
+
"""
|
7 |
+
Add config for DeepLab.
|
8 |
+
"""
|
9 |
+
# We retry random cropping until no single category in semantic segmentation GT occupies more
|
10 |
+
# than `SINGLE_CATEGORY_MAX_AREA` part of the crop.
|
11 |
+
cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
|
12 |
+
# Used for `poly` learning rate schedule.
|
13 |
+
cfg.SOLVER.POLY_LR_POWER = 0.9
|
14 |
+
cfg.SOLVER.POLY_LR_CONSTANT_ENDING = 0.0
|
15 |
+
# Loss type, choose from `cross_entropy`, `hard_pixel_mining`.
|
16 |
+
cfg.MODEL.SEM_SEG_HEAD.LOSS_TYPE = "hard_pixel_mining"
|
17 |
+
# DeepLab settings
|
18 |
+
cfg.MODEL.SEM_SEG_HEAD.PROJECT_FEATURES = ["res2"]
|
19 |
+
cfg.MODEL.SEM_SEG_HEAD.PROJECT_CHANNELS = [48]
|
20 |
+
cfg.MODEL.SEM_SEG_HEAD.ASPP_CHANNELS = 256
|
21 |
+
cfg.MODEL.SEM_SEG_HEAD.ASPP_DILATIONS = [6, 12, 18]
|
22 |
+
cfg.MODEL.SEM_SEG_HEAD.ASPP_DROPOUT = 0.1
|
23 |
+
cfg.MODEL.SEM_SEG_HEAD.USE_DEPTHWISE_SEPARABLE_CONV = False
|
24 |
+
# Backbone new configs
|
25 |
+
cfg.MODEL.RESNETS.RES4_DILATION = 1
|
26 |
+
cfg.MODEL.RESNETS.RES5_MULTI_GRID = [1, 2, 4]
|
27 |
+
# ResNet stem type from: `basic`, `deeplab`
|
28 |
+
cfg.MODEL.RESNETS.STEM_TYPE = "deeplab"
|
GLEE/glee/models/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
GLEE/glee/models/glee_model.py
ADDED
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
2 |
+
"""
|
3 |
+
"""
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch import nn
|
7 |
+
# from ..backbone import build_backbone, Backbone
|
8 |
+
# from ..body.encoder import build_encoder
|
9 |
+
# from ..body.decoder import build_decoder
|
10 |
+
|
11 |
+
from detectron2.modeling import build_backbone
|
12 |
+
|
13 |
+
from .pixel_decoder.maskdino_encoder import build_pixel_decoder
|
14 |
+
from .transformer_decoder.maskdino_decoder import build_transformer_decoder
|
15 |
+
|
16 |
+
import random
|
17 |
+
from transformers import AutoTokenizer
|
18 |
+
from collections import OrderedDict
|
19 |
+
from ..modules.point_features import point_sample
|
20 |
+
from timm.models.layers import trunc_normal_
|
21 |
+
from transformers import CLIPTokenizer,CLIPTextModel
|
22 |
+
from .vos_utils import masks_to_boxes, FeatureFuser
|
23 |
+
import numpy as np
|
24 |
+
import math
|
25 |
+
|
26 |
+
|
27 |
+
def rand_sample(x, max_len):
|
28 |
+
if x.shape[1] <= max_len:
|
29 |
+
return x
|
30 |
+
else:
|
31 |
+
rand_idx = torch.randperm(x.shape[1])[:max_len]
|
32 |
+
return x[:,rand_idx]
|
33 |
+
|
34 |
+
|
35 |
+
def agg_lang_feat(features, mask, pool_type="average"):
|
36 |
+
"""average pooling of language features"""
|
37 |
+
# feat: (bs, seq_len, C)
|
38 |
+
# mask: (bs, seq_len)
|
39 |
+
if pool_type == "average":
|
40 |
+
embedded = features * mask.unsqueeze(-1).float() # use mask to zero out invalid token features
|
41 |
+
aggregate = embedded.sum(1) / (mask.sum(-1).unsqueeze(-1).float())
|
42 |
+
elif pool_type == "max":
|
43 |
+
out = []
|
44 |
+
for i in range(len(features)):
|
45 |
+
pool_feat, _ = torch.max(features[i][mask[i]], 0) # (L, C) -> (C, )
|
46 |
+
out.append(pool_feat)
|
47 |
+
aggregate = torch.stack(out, dim=0) # (bs, C)
|
48 |
+
else:
|
49 |
+
raise ValueError("pool_type should be average or max")
|
50 |
+
return aggregate
|
51 |
+
|
52 |
+
class GLEE_Model(nn.Module):
|
53 |
+
"""
|
54 |
+
Main class for mask classification semantic segmentation architectures.
|
55 |
+
"""
|
56 |
+
def __init__(self, cfg, matcher, device, video_info, contras_mean):
|
57 |
+
super().__init__()
|
58 |
+
self.cfg = cfg
|
59 |
+
self.matcher = matcher
|
60 |
+
self.backbone = build_backbone(cfg)
|
61 |
+
output_channels = [v for k,v in self.backbone._out_feature_channels.items()]
|
62 |
+
self.sot_fuser = FeatureFuser(output_channels[-3:], 256)
|
63 |
+
|
64 |
+
|
65 |
+
self.tokenizer = CLIPTokenizer.from_pretrained('GLEE/clip_vit_base_patch32')
|
66 |
+
self.tokenizer.add_special_tokens({'cls_token': self.tokenizer.eos_token})
|
67 |
+
self.text_encoder = CLIPTextModel.from_pretrained('GLEE/clip_vit_base_patch32')
|
68 |
+
# self.text_encoder_teacher = CLIPTextModel.from_pretrained('GLEE/clip_vit_base_patch32')
|
69 |
+
self.lang_encoder = None
|
70 |
+
# for p in self.text_encoder_teacher.parameters():
|
71 |
+
# p.requires_grad = False
|
72 |
+
self.lang_projection = nn.Parameter(torch.rand(cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM, cfg.MODEL.DIM_PROJ))
|
73 |
+
self.text_encode_type = 'clip_teacher'
|
74 |
+
|
75 |
+
# self.lang_encoder = None
|
76 |
+
self.pixel_decoder = build_pixel_decoder(cfg, self.backbone.output_shape())
|
77 |
+
transformer_predictor_in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
|
78 |
+
self.predictor = build_transformer_decoder(cfg, transformer_predictor_in_channels, lang_encoder = self.lang_encoder, mask_classification=True,)
|
79 |
+
self.to(device)
|
80 |
+
|
81 |
+
self.video_info = video_info
|
82 |
+
self.contras_mean = contras_mean
|
83 |
+
|
84 |
+
self.track_loss_version = cfg.MODEL.TRACK_VERSION
|
85 |
+
|
86 |
+
self.no_mask_tasks = ['obj365', 'obj365_clip','openimage', 'openimage_clip', 'vg', 'grit', 'bdd_det', 'bdd_track_box']
|
87 |
+
|
88 |
+
|
89 |
+
# for visual prompt
|
90 |
+
hidden_dim = 256
|
91 |
+
self.max_spatial_len = [512,512,512,512]
|
92 |
+
self.mask_sptial_embed = nn.ParameterList([nn.Parameter(torch.empty(hidden_dim, hidden_dim)) for x in range(4)])
|
93 |
+
trunc_normal_(self.mask_sptial_embed[0], std=.02)
|
94 |
+
trunc_normal_(self.mask_sptial_embed[1], std=.02)
|
95 |
+
trunc_normal_(self.mask_sptial_embed[2], std=.02)
|
96 |
+
trunc_normal_(self.mask_sptial_embed[3], std=.02)
|
97 |
+
# learnable positive negative indicator
|
98 |
+
self.pn_indicator = nn.Embedding(2, hidden_dim)
|
99 |
+
|
100 |
+
@property
|
101 |
+
def device(self):
|
102 |
+
return self.pixel_mean.device
|
103 |
+
|
104 |
+
def forward(self, images, prompts, task, targets=None, batch_name_list=None, is_train = True, visual_prompt_type='scribble'):
|
105 |
+
extra = {}
|
106 |
+
# dist_loss = None
|
107 |
+
early_semantic = None
|
108 |
+
|
109 |
+
if self.text_encode_type == "clip_teacher":
|
110 |
+
if task not in ['grounding','rvos']:
|
111 |
+
assert batch_name_list
|
112 |
+
calsses_name_list = batch_name_list
|
113 |
+
tokenized = self.tokenizer.batch_encode_plus(calsses_name_list,
|
114 |
+
max_length=self.cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN, # 256
|
115 |
+
padding='max_length' if self.cfg.MODEL.LANGUAGE_BACKBONE.PAD_MAX else "longest", # max_length
|
116 |
+
return_special_tokens_mask=True,
|
117 |
+
return_tensors='pt',
|
118 |
+
truncation=True).to(images.device)
|
119 |
+
texts = (tokenized['input_ids'], tokenized['attention_mask'])
|
120 |
+
token_x = self.text_encoder(*texts)['last_hidden_state']
|
121 |
+
|
122 |
+
valid_mask = tokenized['attention_mask'].bool()
|
123 |
+
# token_x_teacher = self.text_encoder_teacher(*texts)['last_hidden_state']
|
124 |
+
# if is_train:
|
125 |
+
# dist_loss = F.mse_loss(token_x[valid_mask], token_x_teacher[valid_mask] )
|
126 |
+
# F.l2_loss(token_x[valid_mask], token_x_teacher[valid_mask] )
|
127 |
+
token_x = token_x @ self.lang_projection
|
128 |
+
lang_feat_pool = agg_lang_feat(token_x, tokenized['attention_mask'], pool_type="average") # (bs, 768)
|
129 |
+
extra['class_embeddings'] = lang_feat_pool
|
130 |
+
if True: # early_fusion
|
131 |
+
gather_all_classtoken = token_x.flatten(0,1)[tokenized['attention_mask'].flatten(0,1)>0]
|
132 |
+
gather_all_classtoken = gather_all_classtoken.unsqueeze(0).repeat(len(images),1,1) #[bs,L,C]
|
133 |
+
gather_all_classtoken_mask = torch.ones_like(gather_all_classtoken[:,:,0])>0 #[bs,L]
|
134 |
+
early_semantic = {"hidden":gather_all_classtoken.float(),"masks":gather_all_classtoken_mask}
|
135 |
+
|
136 |
+
|
137 |
+
if 'grounding' in prompts:
|
138 |
+
|
139 |
+
if self.text_encode_type == 'clip_frozen' or self.text_encode_type == 'clip_teacher':
|
140 |
+
|
141 |
+
tokens = self.tokenizer(
|
142 |
+
prompts['grounding'], padding='max_length', truncation=True, max_length=self.cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN, return_tensors='pt'
|
143 |
+
)
|
144 |
+
tokens = {key: value.to(images.device) for key, value in tokens.items()}
|
145 |
+
|
146 |
+
texts = (tokens['input_ids'], tokens['attention_mask'])
|
147 |
+
x = self.text_encoder(*texts)
|
148 |
+
token_x = x['last_hidden_state']
|
149 |
+
token_x = token_x @ self.lang_projection
|
150 |
+
|
151 |
+
extra['grounding_tokens'] = token_x.permute(1,0,2) #[len,bz,C]
|
152 |
+
|
153 |
+
non_zero_query_mask = tokens['attention_mask']
|
154 |
+
lang_feat_pool = agg_lang_feat(token_x, non_zero_query_mask, pool_type="average").unsqueeze(1) # (bs, 1, 768)
|
155 |
+
|
156 |
+
dist_loss = (lang_feat_pool*0).sum()
|
157 |
+
|
158 |
+
extra['grounding_nonzero_mask'] = ~non_zero_query_mask.bool() # [bz,len]
|
159 |
+
extra['grounding_class'] = lang_feat_pool.squeeze(1) #[bz,C
|
160 |
+
# gather_all_classtoken = token_x.flatten(0,1)[tokenized['attention_mask'].flatten(0,1)>0]
|
161 |
+
# gather_all_classtoken = gather_all_classtoken.unsqueeze(0).repeat(len(images),1,1) #[bs,L,C]
|
162 |
+
# gather_all_classtoken_mask = torch.ones_like(gather_all_classtoken[:,:,0])>0 #[bs,L]
|
163 |
+
# early_semantic = {"hidden":gather_all_classtoken.float(),"masks":gather_all_classtoken_mask}
|
164 |
+
early_semantic = {"hidden":token_x.float(),"masks":tokens['attention_mask']>0}
|
165 |
+
|
166 |
+
|
167 |
+
if isinstance(images,torch.Tensor):
|
168 |
+
features = self.backbone(images)
|
169 |
+
else:
|
170 |
+
features = self.backbone(images.tensor)
|
171 |
+
|
172 |
+
|
173 |
+
|
174 |
+
|
175 |
+
if 'spatial' in prompts:
|
176 |
+
## setp 1,2,3
|
177 |
+
key_images = [ images ] #bz*[1,3,H,W]
|
178 |
+
key_promptmasks = [m.unsqueeze(0) for m in prompts['spatial']] #bz*[1,1,H,W]
|
179 |
+
|
180 |
+
prompt_mode = visual_prompt_type
|
181 |
+
ref_feats, ref_masks = self.get_template(key_images, key_promptmasks, prompt_mode)
|
182 |
+
early_fusion = {"hidden":ref_feats,"masks":ref_masks}
|
183 |
+
if early_semantic is None:
|
184 |
+
early_semantic = early_fusion
|
185 |
+
else:
|
186 |
+
early_semantic["hidden"] = torch.cat([early_semantic["hidden"],early_fusion["hidden"]],dim=1)
|
187 |
+
early_semantic["masks"] = torch.cat([early_semantic["masks"],early_fusion["masks"]],dim=1)
|
188 |
+
|
189 |
+
|
190 |
+
# bz = len(images)//2
|
191 |
+
mask_features, _, multi_scale_features, zero_loss = self.pixel_decoder.forward_features(features, masks=None, early_fusion = early_semantic)
|
192 |
+
if 'spatial' in prompts:
|
193 |
+
pos_masks = prompts['spatial']
|
194 |
+
# neg_masks = [~p for p in prompts['spatial']]
|
195 |
+
neg_masks = [p&False for p in prompts['spatial']]
|
196 |
+
|
197 |
+
extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
|
198 |
+
|
199 |
+
|
200 |
+
_,h,w = extra['spatial_query_pos_mask'][0].shape
|
201 |
+
divisor = torch.tensor([h,w], device=mask_features.device)[None,]
|
202 |
+
# Get mean pos spatial query
|
203 |
+
non_zero_pos_point = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[-1]).t() for m in extra['spatial_query_pos_mask']]
|
204 |
+
non_zero_pos_point = nn.utils.rnn.pad_sequence(non_zero_pos_point, padding_value=-1).permute(1,0,2)
|
205 |
+
non_zero_pos_mask = (non_zero_pos_point.sum(dim=-1) < 0)
|
206 |
+
spatial_query_pos = point_sample(mask_features, non_zero_pos_point.flip(dims=(2,)).type(mask_features.dtype), align_corners=True) #[(N, C, P)
|
207 |
+
spatial_query_pos = torch.stack([x[m].mean(dim=0, keepdim=True) for x, m in zip(spatial_query_pos.transpose(1,2), ~non_zero_pos_mask)]).transpose(0,1).nan_to_num() # [1,bz,C]
|
208 |
+
# Get mean neg spatial query
|
209 |
+
non_zero_neg_point = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[-1]).t() for m in extra['spatial_query_neg_mask']]
|
210 |
+
non_zero_neg_point = nn.utils.rnn.pad_sequence(non_zero_neg_point, padding_value=-1).permute(1,0,2)
|
211 |
+
non_zero_neg_mask = (non_zero_neg_point.sum(dim=-1) < 0)
|
212 |
+
spatial_query_neg = point_sample(mask_features, non_zero_neg_point.flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
|
213 |
+
spatial_query_neg = torch.stack([x[m].mean(dim=0, keepdim=True) for x, m in zip(spatial_query_neg.transpose(1,2), ~non_zero_neg_mask)]).transpose(0,1).nan_to_num()
|
214 |
+
|
215 |
+
# Get layerwise spatial query
|
216 |
+
src_spatial_queries = []
|
217 |
+
src_spatial_maskings = []
|
218 |
+
for i in range(len(multi_scale_features)):
|
219 |
+
bs,dc,h,w = multi_scale_features[i].shape
|
220 |
+
# src_mask_features = multi_scale_features[i].view(h,w,bs,dc)
|
221 |
+
src_mask_features = multi_scale_features[i].permute(2,3,0,1)
|
222 |
+
src_mask_features = src_mask_features @ self.mask_sptial_embed[i]
|
223 |
+
|
224 |
+
non_zero_query_point_pos = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[i]).t() for m in extra['spatial_query_pos_mask']]
|
225 |
+
non_zero_query_point_neg = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[i]).t() for m in extra['spatial_query_neg_mask']]
|
226 |
+
non_zero_query_point = [torch.cat([x,y], dim=0) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
|
227 |
+
pos_neg_indicator = [torch.cat([torch.ones(x.shape[0], device=x.device), -torch.ones(y.shape[0], device=y.device)]) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
|
228 |
+
pos_neg_indicator = nn.utils.rnn.pad_sequence(pos_neg_indicator, padding_value=0)
|
229 |
+
non_zero_query_point = nn.utils.rnn.pad_sequence(non_zero_query_point, padding_value=-1).permute(1,0,2)
|
230 |
+
non_zero_query_mask = (non_zero_query_point.sum(dim=-1) < 0)
|
231 |
+
non_zero_query_point[non_zero_query_mask] = 0
|
232 |
+
|
233 |
+
spatial_tokens = point_sample(src_mask_features.permute(2,3,0,1), non_zero_query_point.flip(dims=(2,)).type(src_mask_features.dtype), align_corners=True).permute(2,0,1)
|
234 |
+
spatial_tokens[pos_neg_indicator==1] += self.pn_indicator.weight[0:1]
|
235 |
+
spatial_tokens[pos_neg_indicator==-1] += self.pn_indicator.weight[1:2]
|
236 |
+
|
237 |
+
src_spatial_queries += [spatial_tokens]
|
238 |
+
src_spatial_maskings += [non_zero_query_mask]
|
239 |
+
|
240 |
+
extra['visual_prompt_tokens'] = src_spatial_queries #[len,bz,C]
|
241 |
+
extra['visual_prompt_nonzero_mask'] = src_spatial_maskings # [bz,len]
|
242 |
+
|
243 |
+
|
244 |
+
outputs = self.predictor(multi_scale_features, mask_features, extra=extra, task=task, masks=None, targets=targets)
|
245 |
+
return outputs
|
246 |
+
|
247 |
+
|
248 |
+
|
249 |
+
|
250 |
+
|
251 |
+
|
252 |
+
def get_template(self, imgs, pad_masks, prompt_mode='scribble'):
|
253 |
+
"""img: (N, 3, H, W), mask: (N, 1, H, W), bbox: (1, 4)"""
|
254 |
+
"""get 4-channel template"""
|
255 |
+
|
256 |
+
croped_img_with_mask = []
|
257 |
+
|
258 |
+
for image_i, mask_i in zip( imgs, pad_masks):
|
259 |
+
|
260 |
+
if prompt_mode in ['scribble','point']:
|
261 |
+
image_with_mask = image_i + mask_i.to(image_i)
|
262 |
+
else:
|
263 |
+
image_with_mask = image_i
|
264 |
+
|
265 |
+
# image_with_mask = torch.cat([image_i,mask_i.to(image_i)],dim=1) #[1,3,H,W]
|
266 |
+
box_i = masks_to_boxes(mask_i[0]) #[xyxy]
|
267 |
+
box_i[:, 2:] = box_i[:, 2:] - box_i[:, :2] #xywh
|
268 |
+
|
269 |
+
|
270 |
+
x, y, w, h = box_i[0].long().tolist()
|
271 |
+
|
272 |
+
self.search_area_factor=2
|
273 |
+
|
274 |
+
crop_sz = math.ceil(math.sqrt(w * h) * self.search_area_factor)
|
275 |
+
x1 = max(0,round(x + 0.5 * w - crop_sz * 0.5))
|
276 |
+
x2 = x1 + crop_sz
|
277 |
+
y1 = max(0,round(y + 0.5 * h - crop_sz * 0.5))
|
278 |
+
y2 = y1 + crop_sz
|
279 |
+
|
280 |
+
im_crop = image_with_mask[:, :, y1:y2, x1:x2]
|
281 |
+
# resize
|
282 |
+
if im_crop.shape[-1] ==0 or im_crop.shape[-2] ==0 :
|
283 |
+
im_crop = image_with_mask
|
284 |
+
im_crop = F.interpolate(im_crop, (256,256), mode='bilinear', align_corners=False)
|
285 |
+
croped_img_with_mask.append(im_crop)
|
286 |
+
croped_img_with_mask = torch.cat(croped_img_with_mask,dim=0) #[bz,3,256,256]
|
287 |
+
with torch.no_grad():
|
288 |
+
ref_srcs = self.backbone(croped_img_with_mask.contiguous())
|
289 |
+
ref_srcs = [v for k,v in ref_srcs.items()]
|
290 |
+
ref_feats = self.sot_fuser(ref_srcs[1:]).float() #[bz,256,32,32]
|
291 |
+
|
292 |
+
ref_feats = ref_feats.flatten(-2).permute(0, 2, 1) # (bs, L, C)
|
293 |
+
ref_masks = torch.ones_like(ref_feats[:,:,0])>0 #[bs,L]
|
294 |
+
|
295 |
+
return ref_feats, ref_masks
|
296 |
+
|
GLEE/glee/models/pixel_decoder/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# Copyright (c) IDEA, Inc. and its affiliates.
|
GLEE/glee/models/pixel_decoder/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (184 Bytes). View file
|
|
GLEE/glee/models/pixel_decoder/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (154 Bytes). View file
|
|
GLEE/glee/models/pixel_decoder/__pycache__/early_fusion.cpython-38.pyc
ADDED
Binary file (6.36 kB). View file
|
|
GLEE/glee/models/pixel_decoder/__pycache__/early_fusion.cpython-39.pyc
ADDED
Binary file (6.3 kB). View file
|
|
GLEE/glee/models/pixel_decoder/__pycache__/maskdino_encoder.cpython-38.pyc
ADDED
Binary file (15.4 kB). View file
|
|
GLEE/glee/models/pixel_decoder/__pycache__/maskdino_encoder.cpython-39.pyc
ADDED
Binary file (15.3 kB). View file
|
|
GLEE/glee/models/pixel_decoder/__pycache__/position_encoding.cpython-38.pyc
ADDED
Binary file (2.64 kB). View file
|
|
GLEE/glee/models/pixel_decoder/__pycache__/position_encoding.cpython-39.pyc
ADDED
Binary file (2.6 kB). View file
|
|
GLEE/glee/models/pixel_decoder/early_fusion.py
ADDED
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torch import nn
|
4 |
+
from timm.models.layers import DropPath
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
|
9 |
+
class VLFuse(torch.nn.Module):
|
10 |
+
"""
|
11 |
+
Early Fusion Module
|
12 |
+
"""
|
13 |
+
|
14 |
+
def __init__(self, ):
|
15 |
+
super(VLFuse, self).__init__()
|
16 |
+
self.init_configs()
|
17 |
+
|
18 |
+
# early fusion module
|
19 |
+
# bi-direction (text->image, image->text)
|
20 |
+
self.b_attn = BiAttentionBlockForCheckpoint(v_dim=self.img_dim, # 256
|
21 |
+
l_dim=self.lang_dim, # 768
|
22 |
+
embed_dim=self.embed_dim, # 2048
|
23 |
+
num_heads=self.n_head, # 8
|
24 |
+
dropout=0.1,
|
25 |
+
drop_path=.0,
|
26 |
+
init_values=1.0 / 6,
|
27 |
+
)
|
28 |
+
def init_configs(self, ):
|
29 |
+
# common params
|
30 |
+
self.img_dim = 256
|
31 |
+
|
32 |
+
self.max_query_len = 256
|
33 |
+
self.n_layers =1
|
34 |
+
|
35 |
+
# mha params
|
36 |
+
self.n_head = 8
|
37 |
+
self.embed_dim = 2048 # 2048 by default
|
38 |
+
|
39 |
+
self.lang_dim = 256
|
40 |
+
|
41 |
+
def forward(self, x, task=None):
|
42 |
+
visual_features = x["visual"]
|
43 |
+
language_dict_features = x["lang"]
|
44 |
+
|
45 |
+
fused_visual_features, language_features = self.b_attn(
|
46 |
+
visual_features, language_dict_features['hidden'], language_dict_features['masks'], task)
|
47 |
+
|
48 |
+
language_dict_features['hidden'] = language_features
|
49 |
+
fused_language_dict_features = language_dict_features
|
50 |
+
|
51 |
+
features_dict = {"visual": fused_visual_features,
|
52 |
+
"lang": fused_language_dict_features}
|
53 |
+
|
54 |
+
return features_dict
|
55 |
+
|
56 |
+
|
57 |
+
class BiMultiHeadAttention(nn.Module):
|
58 |
+
def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1):
|
59 |
+
super(BiMultiHeadAttention, self).__init__()
|
60 |
+
|
61 |
+
self.embed_dim = embed_dim
|
62 |
+
self.num_heads = num_heads
|
63 |
+
self.head_dim = embed_dim // num_heads
|
64 |
+
self.v_dim = v_dim
|
65 |
+
self.l_dim = l_dim
|
66 |
+
|
67 |
+
assert (
|
68 |
+
self.head_dim * self.num_heads == self.embed_dim
|
69 |
+
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
|
70 |
+
self.scale = self.head_dim ** (-0.5)
|
71 |
+
self.dropout = dropout
|
72 |
+
|
73 |
+
self.v_proj = nn.Linear(self.v_dim, self.embed_dim)
|
74 |
+
self.l_proj = nn.Linear(self.l_dim, self.embed_dim)
|
75 |
+
self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim)
|
76 |
+
self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim)
|
77 |
+
|
78 |
+
self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim)
|
79 |
+
self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim)
|
80 |
+
|
81 |
+
self.stable_softmax_2d = False
|
82 |
+
self.clamp_min_for_underflow = True
|
83 |
+
self.clamp_max_for_overflow = True
|
84 |
+
|
85 |
+
self._reset_parameters()
|
86 |
+
|
87 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
88 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
89 |
+
|
90 |
+
def _reset_parameters(self):
|
91 |
+
nn.init.xavier_uniform_(self.v_proj.weight)
|
92 |
+
self.v_proj.bias.data.fill_(0)
|
93 |
+
nn.init.xavier_uniform_(self.l_proj.weight)
|
94 |
+
self.l_proj.bias.data.fill_(0)
|
95 |
+
nn.init.xavier_uniform_(self.values_v_proj.weight)
|
96 |
+
self.values_v_proj.bias.data.fill_(0)
|
97 |
+
nn.init.xavier_uniform_(self.values_l_proj.weight)
|
98 |
+
self.values_l_proj.bias.data.fill_(0)
|
99 |
+
nn.init.xavier_uniform_(self.out_v_proj.weight)
|
100 |
+
self.out_v_proj.bias.data.fill_(0)
|
101 |
+
nn.init.xavier_uniform_(self.out_l_proj.weight)
|
102 |
+
self.out_l_proj.bias.data.fill_(0)
|
103 |
+
|
104 |
+
def forward(self, v, l, attention_mask_l=None):
|
105 |
+
bsz, tgt_len, embed_dim = v.size()
|
106 |
+
|
107 |
+
query_states = self.v_proj(v) * self.scale
|
108 |
+
key_states = self._shape(self.l_proj(l), -1, bsz)
|
109 |
+
value_v_states = self._shape(self.values_v_proj(v), -1, bsz)
|
110 |
+
value_l_states = self._shape(self.values_l_proj(l), -1, bsz)
|
111 |
+
|
112 |
+
proj_shape = (bsz * self.num_heads, -1, self.head_dim) # (bs * 8, -1, embed_dim//8)
|
113 |
+
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) # (bs * 8, seq_len_img, embed_dim//8)
|
114 |
+
key_states = key_states.view(*proj_shape) # (bs * 8, seq_len_text, embed_dim//8)
|
115 |
+
value_v_states = value_v_states.view(*proj_shape)
|
116 |
+
value_l_states = value_l_states.view(*proj_shape)
|
117 |
+
|
118 |
+
src_len = key_states.size(1)
|
119 |
+
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) # (bs * 8, seq_len_img, seq_len_text)
|
120 |
+
|
121 |
+
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
122 |
+
raise ValueError(
|
123 |
+
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
|
124 |
+
)
|
125 |
+
|
126 |
+
# attn_weights_l = nn.functional.softmax(attn_weights.transpose(1, 2), dim=-1)
|
127 |
+
|
128 |
+
if self.stable_softmax_2d:
|
129 |
+
attn_weights = attn_weights - attn_weights.max()
|
130 |
+
|
131 |
+
if self.clamp_min_for_underflow:
|
132 |
+
attn_weights = torch.clamp(attn_weights, min=-50000) # Do not increase -50000, data type half has quite limited range
|
133 |
+
if self.clamp_max_for_overflow:
|
134 |
+
attn_weights = torch.clamp(attn_weights, max=50000) # Do not increase 50000, data type half has quite limited range
|
135 |
+
|
136 |
+
attn_weights_T = attn_weights.transpose(1, 2)
|
137 |
+
attn_weights_l = (attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[
|
138 |
+
0])
|
139 |
+
if self.clamp_min_for_underflow:
|
140 |
+
attn_weights_l = torch.clamp(attn_weights_l, min=-50000) # Do not increase -50000, data type half has quite limited range
|
141 |
+
if self.clamp_max_for_overflow:
|
142 |
+
attn_weights_l = torch.clamp(attn_weights_l, max=50000) # Do not increase 50000, data type half has quite limited range
|
143 |
+
|
144 |
+
attn_weights_l = attn_weights_l.softmax(dim=-1)
|
145 |
+
# assert attention_mask_l.dtype == torch.int64
|
146 |
+
if attention_mask_l is not None:
|
147 |
+
assert (attention_mask_l.dim() == 2) # (bs, seq_len)
|
148 |
+
attention_mask = attention_mask_l.unsqueeze(1).unsqueeze(1) # (bs, 1, 1, seq_len)
|
149 |
+
attention_mask = attention_mask.expand(bsz, 1, tgt_len, src_len)
|
150 |
+
attention_mask = attention_mask.masked_fill(attention_mask == 0, -9e15)
|
151 |
+
|
152 |
+
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
153 |
+
raise ValueError(
|
154 |
+
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}"
|
155 |
+
)
|
156 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
157 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
158 |
+
|
159 |
+
attn_weights_v = nn.functional.softmax(attn_weights, dim=-1)
|
160 |
+
|
161 |
+
attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training)
|
162 |
+
attn_probs_l = F.dropout(attn_weights_l, p=self.dropout, training=self.training)
|
163 |
+
|
164 |
+
attn_output_v = torch.bmm(attn_probs_v, value_l_states)
|
165 |
+
attn_output_l = torch.bmm(attn_probs_l, value_v_states)
|
166 |
+
|
167 |
+
|
168 |
+
if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
169 |
+
raise ValueError(
|
170 |
+
f"`attn_output_v` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output_v.size()}"
|
171 |
+
)
|
172 |
+
|
173 |
+
if attn_output_l.size() != (bsz * self.num_heads, src_len, self.head_dim):
|
174 |
+
raise ValueError(
|
175 |
+
f"`attn_output_l` should be of size {(bsz, self.num_heads, src_len, self.head_dim)}, but is {attn_output_l.size()}"
|
176 |
+
)
|
177 |
+
|
178 |
+
attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
179 |
+
attn_output_v = attn_output_v.transpose(1, 2)
|
180 |
+
attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim)
|
181 |
+
|
182 |
+
attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, self.head_dim)
|
183 |
+
attn_output_l = attn_output_l.transpose(1, 2)
|
184 |
+
attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim)
|
185 |
+
|
186 |
+
attn_output_v = self.out_v_proj(attn_output_v)
|
187 |
+
attn_output_l = self.out_l_proj(attn_output_l)
|
188 |
+
|
189 |
+
return attn_output_v, attn_output_l
|
190 |
+
|
191 |
+
|
192 |
+
class BiAttentionBlockForCheckpoint(nn.Module):
|
193 |
+
def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1,
|
194 |
+
drop_path=.0, init_values=1e-4, ):
|
195 |
+
"""
|
196 |
+
Inputs:
|
197 |
+
embed_dim - Dimensionality of input and attention feature vectors
|
198 |
+
num_heads - Number of heads to use in the Multi-Head Attention block
|
199 |
+
dropout - Amount of dropout to apply in the feed-forward network
|
200 |
+
"""
|
201 |
+
super(BiAttentionBlockForCheckpoint, self).__init__()
|
202 |
+
|
203 |
+
# pre layer norm
|
204 |
+
self.layer_norm_v = nn.LayerNorm(v_dim)
|
205 |
+
self.layer_norm_l = nn.LayerNorm(l_dim)
|
206 |
+
self.attn = BiMultiHeadAttention(v_dim=v_dim,
|
207 |
+
l_dim=l_dim,
|
208 |
+
embed_dim=embed_dim,
|
209 |
+
num_heads=num_heads,
|
210 |
+
dropout=dropout,
|
211 |
+
)
|
212 |
+
|
213 |
+
# add layer scale for training stability
|
214 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
215 |
+
self.gamma_v = nn.Parameter(init_values * torch.ones((v_dim)), requires_grad=True)
|
216 |
+
self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=True)
|
217 |
+
|
218 |
+
|
219 |
+
def forward(self, v, l, attention_mask_l=None, task=None):
|
220 |
+
# v: visual features, (bs, sigma(HW), 256)
|
221 |
+
# l: language features, (bs, seq_len, 768)
|
222 |
+
v = self.layer_norm_v(v)
|
223 |
+
l = self.layer_norm_l(l)
|
224 |
+
delta_v, delta_l = self.attn(v, l, attention_mask_l=attention_mask_l)
|
225 |
+
# v, l = v + delta_v, l + delta_l
|
226 |
+
v = v + self.drop_path(self.gamma_v * delta_v)
|
227 |
+
l = l + self.drop_path(self.gamma_l * delta_l)
|
228 |
+
return v, l
|
229 |
+
|
230 |
+
|
GLEE/glee/models/pixel_decoder/maskdino_encoder.py
ADDED
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# DINO
|
3 |
+
# Copyright (c) 2022 IDEA. All Rights Reserved.
|
4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
5 |
+
# ------------------------------------------------------------------------
|
6 |
+
# Modified by Feng Li and Hao Zhang.
|
7 |
+
import logging
|
8 |
+
import numpy as np
|
9 |
+
from typing import Callable, Dict, List, Optional, Tuple, Union
|
10 |
+
import fvcore.nn.weight_init as weight_init
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from torch import nn
|
14 |
+
from torch.nn import functional as F
|
15 |
+
from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
|
16 |
+
from torch.cuda.amp import autocast
|
17 |
+
|
18 |
+
from detectron2.config import configurable
|
19 |
+
from detectron2.layers import Conv2d, ShapeSpec, get_norm
|
20 |
+
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
|
21 |
+
|
22 |
+
from .position_encoding import PositionEmbeddingSine
|
23 |
+
from ...utils.utils import _get_clones, _get_clones_advanced, _get_activation_fn
|
24 |
+
from .ops.modules import MSDeformAttn
|
25 |
+
from .early_fusion import VLFuse
|
26 |
+
|
27 |
+
def build_pixel_decoder(cfg, input_shape):
|
28 |
+
"""
|
29 |
+
Build a pixel decoder from `cfg.MODEL.MaskDINO.PIXEL_DECODER_NAME`.
|
30 |
+
"""
|
31 |
+
name = cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME
|
32 |
+
model = SEM_SEG_HEADS_REGISTRY.get(name)(cfg, input_shape)
|
33 |
+
forward_features = getattr(model, "forward_features", None)
|
34 |
+
if not callable(forward_features):
|
35 |
+
raise ValueError(
|
36 |
+
"Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. "
|
37 |
+
f"Please implement forward_features for {name} to only return mask features."
|
38 |
+
)
|
39 |
+
return model
|
40 |
+
|
41 |
+
|
42 |
+
# MSDeformAttn Transformer encoder in deformable detr
|
43 |
+
class MSDeformAttnTransformerEncoderOnly(nn.Module):
|
44 |
+
def __init__(self, d_model=256, nhead=8,
|
45 |
+
num_encoder_layers=6, dim_feedforward=1024, dropout=0.1,
|
46 |
+
activation="relu",
|
47 |
+
num_feature_levels=4, enc_n_points=4,):
|
48 |
+
super().__init__()
|
49 |
+
|
50 |
+
self.d_model = d_model
|
51 |
+
self.nhead = nhead
|
52 |
+
|
53 |
+
vl_fusion_layer = VLFuse()
|
54 |
+
|
55 |
+
encoder_layer = MSDeformAttnTransformerEncoderLayer(d_model, dim_feedforward,
|
56 |
+
dropout, activation,
|
57 |
+
num_feature_levels, nhead, enc_n_points)
|
58 |
+
self.encoder = MSDeformAttnTransformerEncoder(vl_fusion_layer, encoder_layer, num_encoder_layers)
|
59 |
+
|
60 |
+
self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
|
61 |
+
|
62 |
+
self._reset_parameters()
|
63 |
+
|
64 |
+
def _reset_parameters(self):
|
65 |
+
for p in self.parameters():
|
66 |
+
if p.dim() > 1:
|
67 |
+
nn.init.xavier_uniform_(p)
|
68 |
+
for m in self.modules():
|
69 |
+
if isinstance(m, MSDeformAttn):
|
70 |
+
m._reset_parameters()
|
71 |
+
normal_(self.level_embed)
|
72 |
+
|
73 |
+
def get_valid_ratio(self, mask):
|
74 |
+
_, H, W = mask.shape
|
75 |
+
valid_H = torch.sum(~mask[:, :, 0], 1)
|
76 |
+
valid_W = torch.sum(~mask[:, 0, :], 1)
|
77 |
+
valid_ratio_h = valid_H.float() / H
|
78 |
+
valid_ratio_w = valid_W.float() / W
|
79 |
+
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
|
80 |
+
return valid_ratio
|
81 |
+
|
82 |
+
def forward(self, srcs, masks, pos_embeds, early_fusion=None):
|
83 |
+
|
84 |
+
enable_mask=0
|
85 |
+
if masks is not None:
|
86 |
+
for src in srcs:
|
87 |
+
if src.size(2)%32 or src.size(3)%32:
|
88 |
+
enable_mask = 1
|
89 |
+
if enable_mask==0:
|
90 |
+
masks = [torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in srcs]
|
91 |
+
# prepare input for encoder
|
92 |
+
src_flatten = []
|
93 |
+
mask_flatten = []
|
94 |
+
lvl_pos_embed_flatten = []
|
95 |
+
spatial_shapes = []
|
96 |
+
for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
|
97 |
+
bs, c, h, w = src.shape
|
98 |
+
spatial_shape = (h, w)
|
99 |
+
spatial_shapes.append(spatial_shape)
|
100 |
+
src = src.flatten(2).transpose(1, 2)
|
101 |
+
mask = mask.flatten(1)
|
102 |
+
pos_embed = pos_embed.flatten(2).transpose(1, 2)
|
103 |
+
lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
|
104 |
+
lvl_pos_embed_flatten.append(lvl_pos_embed)
|
105 |
+
src_flatten.append(src)
|
106 |
+
mask_flatten.append(mask)
|
107 |
+
src_flatten = torch.cat(src_flatten, 1)
|
108 |
+
mask_flatten = torch.cat(mask_flatten, 1)
|
109 |
+
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
|
110 |
+
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
|
111 |
+
level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
|
112 |
+
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
|
113 |
+
# encoder
|
114 |
+
memory, zero_loss = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten, early_fusion)
|
115 |
+
|
116 |
+
return memory, spatial_shapes, level_start_index, zero_loss
|
117 |
+
|
118 |
+
|
119 |
+
class MSDeformAttnTransformerEncoderLayer(nn.Module):
|
120 |
+
def __init__(self,
|
121 |
+
d_model=256, d_ffn=1024,
|
122 |
+
dropout=0.1, activation="relu",
|
123 |
+
n_levels=4, n_heads=8, n_points=4):
|
124 |
+
super().__init__()
|
125 |
+
|
126 |
+
# self attention
|
127 |
+
self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
|
128 |
+
self.dropout1 = nn.Dropout(dropout)
|
129 |
+
self.norm1 = nn.LayerNorm(d_model)
|
130 |
+
|
131 |
+
# ffn
|
132 |
+
self.linear1 = nn.Linear(d_model, d_ffn)
|
133 |
+
self.activation = _get_activation_fn(activation)
|
134 |
+
self.dropout2 = nn.Dropout(dropout)
|
135 |
+
self.linear2 = nn.Linear(d_ffn, d_model)
|
136 |
+
self.dropout3 = nn.Dropout(dropout)
|
137 |
+
self.norm2 = nn.LayerNorm(d_model)
|
138 |
+
|
139 |
+
@staticmethod
|
140 |
+
def with_pos_embed(tensor, pos):
|
141 |
+
return tensor if pos is None else tensor + pos
|
142 |
+
|
143 |
+
def forward_ffn(self, src):
|
144 |
+
src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
|
145 |
+
src = src + self.dropout3(src2)
|
146 |
+
src = self.norm2(src)
|
147 |
+
return src
|
148 |
+
|
149 |
+
def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):
|
150 |
+
# self attention
|
151 |
+
src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask)
|
152 |
+
src = src + self.dropout1(src2)
|
153 |
+
src = self.norm1(src)
|
154 |
+
|
155 |
+
# ffn
|
156 |
+
src = self.forward_ffn(src)
|
157 |
+
|
158 |
+
return src
|
159 |
+
|
160 |
+
|
161 |
+
class MSDeformAttnTransformerEncoder(nn.Module):
|
162 |
+
def __init__(self, vl_fusion_layer, encoder_layer, num_layers):
|
163 |
+
super().__init__()
|
164 |
+
self.layers = _get_clones(encoder_layer, num_layers)
|
165 |
+
self.num_layers = num_layers
|
166 |
+
|
167 |
+
self.vl_layers = _get_clones_advanced(vl_fusion_layer, num_layers, 1)
|
168 |
+
|
169 |
+
|
170 |
+
@staticmethod
|
171 |
+
def get_reference_points(spatial_shapes, valid_ratios, device):
|
172 |
+
reference_points_list = []
|
173 |
+
for lvl, (H_, W_) in enumerate(spatial_shapes):
|
174 |
+
|
175 |
+
ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
|
176 |
+
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
|
177 |
+
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
|
178 |
+
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
|
179 |
+
ref = torch.stack((ref_x, ref_y), -1)
|
180 |
+
reference_points_list.append(ref)
|
181 |
+
reference_points = torch.cat(reference_points_list, 1)
|
182 |
+
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
|
183 |
+
return reference_points
|
184 |
+
|
185 |
+
def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None, early_fusion=None):
|
186 |
+
|
187 |
+
if early_fusion:
|
188 |
+
output = {"visual": src, "lang": early_fusion}
|
189 |
+
else:
|
190 |
+
output = src
|
191 |
+
|
192 |
+
reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
|
193 |
+
for _, (layer,vl_layer) in enumerate(zip(self.layers, self.vl_layers)):
|
194 |
+
if early_fusion:
|
195 |
+
output = vl_layer(output)
|
196 |
+
output["visual"] = layer(output["visual"], pos, reference_points, spatial_shapes, level_start_index, padding_mask)
|
197 |
+
else:
|
198 |
+
output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)
|
199 |
+
|
200 |
+
|
201 |
+
if early_fusion:
|
202 |
+
return output["visual"] , (output['lang']['hidden']*0).sum()
|
203 |
+
else:
|
204 |
+
return output, None
|
205 |
+
|
206 |
+
|
207 |
+
@SEM_SEG_HEADS_REGISTRY.register()
|
208 |
+
class MaskDINOEncoder(nn.Module):
|
209 |
+
"""
|
210 |
+
This is the multi-scale encoder in detection models, also named as pixel decoder in segmentation models.
|
211 |
+
"""
|
212 |
+
@configurable
|
213 |
+
def __init__(
|
214 |
+
self,
|
215 |
+
input_shape: Dict[str, ShapeSpec],
|
216 |
+
*,
|
217 |
+
transformer_dropout: float,
|
218 |
+
transformer_nheads: int,
|
219 |
+
transformer_dim_feedforward: int,
|
220 |
+
transformer_enc_layers: int,
|
221 |
+
conv_dim: int,
|
222 |
+
mask_dim: int,
|
223 |
+
norm: Optional[Union[str, Callable]] = None,
|
224 |
+
# deformable transformer encoder args
|
225 |
+
transformer_in_features: List[str],
|
226 |
+
common_stride: int,
|
227 |
+
num_feature_levels: int,
|
228 |
+
total_num_feature_levels: int,
|
229 |
+
feature_order: str,
|
230 |
+
ViTBackbone: bool,
|
231 |
+
):
|
232 |
+
"""
|
233 |
+
NOTE: this interface is experimental.
|
234 |
+
Args:
|
235 |
+
input_shape: shapes (channels and stride) of the input features
|
236 |
+
transformer_dropout: dropout probability in transformer
|
237 |
+
transformer_nheads: number of heads in transformer
|
238 |
+
transformer_dim_feedforward: dimension of feedforward network
|
239 |
+
transformer_enc_layers: number of transformer encoder layers
|
240 |
+
conv_dims: number of output channels for the intermediate conv layers.
|
241 |
+
mask_dim: number of output channels for the final conv layer.
|
242 |
+
norm (str or callable): normalization for all conv layers
|
243 |
+
num_feature_levels: feature scales used
|
244 |
+
total_num_feature_levels: total feautre scales used (include the downsampled features)
|
245 |
+
feature_order: 'low2high' or 'high2low', i.e., 'low2high' means low-resolution features are put in the first.
|
246 |
+
"""
|
247 |
+
super().__init__()
|
248 |
+
transformer_input_shape = {
|
249 |
+
k: v for k, v in input_shape.items() if k in transformer_in_features
|
250 |
+
}
|
251 |
+
# this is the input shape of pixel decoder
|
252 |
+
input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
|
253 |
+
self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5"
|
254 |
+
self.feature_strides = [v.stride for k, v in input_shape]
|
255 |
+
self.feature_channels = [v.channels for k, v in input_shape]
|
256 |
+
self.feature_order = feature_order
|
257 |
+
|
258 |
+
if feature_order == "low2high":
|
259 |
+
transformer_input_shape = sorted(transformer_input_shape.items(), key=lambda x: -x[1].stride)
|
260 |
+
else:
|
261 |
+
transformer_input_shape = sorted(transformer_input_shape.items(), key=lambda x: x[1].stride)
|
262 |
+
self.transformer_in_features = [k for k, v in transformer_input_shape] # starting from "res2" to "res5"
|
263 |
+
transformer_in_channels = [v.channels for k, v in transformer_input_shape]
|
264 |
+
self.transformer_feature_strides = [v.stride for k, v in transformer_input_shape] # to decide extra FPN layers
|
265 |
+
self.maskdino_num_feature_levels = num_feature_levels # always use 3 scales
|
266 |
+
self.total_num_feature_levels = total_num_feature_levels
|
267 |
+
self.common_stride = common_stride
|
268 |
+
|
269 |
+
self.transformer_num_feature_levels = len(self.transformer_in_features)
|
270 |
+
self.low_resolution_index = transformer_in_channels.index(max(transformer_in_channels))
|
271 |
+
self.high_resolution_index = 0 if self.feature_order == 'low2high' else -1
|
272 |
+
|
273 |
+
self.isViTBackbone = ViTBackbone
|
274 |
+
if not ViTBackbone:
|
275 |
+
if self.transformer_num_feature_levels > 1:
|
276 |
+
input_proj_list = []
|
277 |
+
for in_channels in transformer_in_channels[::-1]:
|
278 |
+
input_proj_list.append(nn.Sequential(
|
279 |
+
nn.Conv2d(in_channels, conv_dim, kernel_size=1),
|
280 |
+
nn.GroupNorm(32, conv_dim),
|
281 |
+
))
|
282 |
+
# input projectino for downsample
|
283 |
+
in_channels = max(transformer_in_channels)
|
284 |
+
for _ in range(self.total_num_feature_levels - self.transformer_num_feature_levels): # exclude the res2
|
285 |
+
input_proj_list.append(nn.Sequential(
|
286 |
+
nn.Conv2d(in_channels, conv_dim, kernel_size=3, stride=2, padding=1),
|
287 |
+
nn.GroupNorm(32, conv_dim),
|
288 |
+
))
|
289 |
+
in_channels = conv_dim
|
290 |
+
self.input_proj = nn.ModuleList(input_proj_list)
|
291 |
+
else:
|
292 |
+
self.input_proj = nn.ModuleList([
|
293 |
+
nn.Sequential(
|
294 |
+
nn.Conv2d(transformer_in_channels[-1], conv_dim, kernel_size=1),
|
295 |
+
nn.GroupNorm(32, conv_dim),
|
296 |
+
)])
|
297 |
+
|
298 |
+
for proj in self.input_proj:
|
299 |
+
nn.init.xavier_uniform_(proj[0].weight, gain=1)
|
300 |
+
nn.init.constant_(proj[0].bias, 0)
|
301 |
+
|
302 |
+
self.transformer = MSDeformAttnTransformerEncoderOnly(
|
303 |
+
d_model=conv_dim,
|
304 |
+
dropout=transformer_dropout,
|
305 |
+
nhead=transformer_nheads,
|
306 |
+
dim_feedforward=transformer_dim_feedforward,
|
307 |
+
num_encoder_layers=transformer_enc_layers,
|
308 |
+
num_feature_levels=self.total_num_feature_levels,
|
309 |
+
)
|
310 |
+
N_steps = conv_dim // 2
|
311 |
+
self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
|
312 |
+
|
313 |
+
self.mask_dim = mask_dim
|
314 |
+
# use 1x1 conv instead
|
315 |
+
self.mask_features = Conv2d(
|
316 |
+
conv_dim,
|
317 |
+
mask_dim,
|
318 |
+
kernel_size=1,
|
319 |
+
stride=1,
|
320 |
+
padding=0,
|
321 |
+
)
|
322 |
+
weight_init.c2_xavier_fill(self.mask_features)
|
323 |
+
# extra fpn levels
|
324 |
+
stride = min(self.transformer_feature_strides)
|
325 |
+
self.num_fpn_levels = max(int(np.log2(stride) - np.log2(self.common_stride)), 1)
|
326 |
+
|
327 |
+
lateral_convs = []
|
328 |
+
output_convs = []
|
329 |
+
|
330 |
+
use_bias = norm == ""
|
331 |
+
for idx, in_channels in enumerate(self.feature_channels[:self.num_fpn_levels]):
|
332 |
+
lateral_norm = get_norm(norm, conv_dim)
|
333 |
+
output_norm = get_norm(norm, conv_dim)
|
334 |
+
|
335 |
+
lateral_conv = Conv2d(
|
336 |
+
in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm
|
337 |
+
)
|
338 |
+
output_conv = Conv2d(
|
339 |
+
conv_dim,
|
340 |
+
conv_dim,
|
341 |
+
kernel_size=3,
|
342 |
+
stride=1,
|
343 |
+
padding=1,
|
344 |
+
bias=use_bias,
|
345 |
+
norm=output_norm,
|
346 |
+
activation=F.relu,
|
347 |
+
)
|
348 |
+
weight_init.c2_xavier_fill(lateral_conv)
|
349 |
+
weight_init.c2_xavier_fill(output_conv)
|
350 |
+
self.add_module("adapter_{}".format(idx + 1), lateral_conv)
|
351 |
+
self.add_module("layer_{}".format(idx + 1), output_conv)
|
352 |
+
|
353 |
+
lateral_convs.append(lateral_conv)
|
354 |
+
output_convs.append(output_conv)
|
355 |
+
# Place convs into top-down order (from low to high resolution)
|
356 |
+
# to make the top-down computation in forward clearer.
|
357 |
+
self.lateral_convs = lateral_convs[::-1]
|
358 |
+
self.output_convs = output_convs[::-1]
|
359 |
+
|
360 |
+
@classmethod
|
361 |
+
def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
|
362 |
+
ret = {}
|
363 |
+
ret["input_shape"] = {
|
364 |
+
k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
|
365 |
+
}
|
366 |
+
ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
|
367 |
+
ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
|
368 |
+
ret["norm"] = cfg.MODEL.SEM_SEG_HEAD.NORM
|
369 |
+
ret["transformer_dropout"] = cfg.MODEL.MaskDINO.DROPOUT
|
370 |
+
ret["transformer_nheads"] = cfg.MODEL.MaskDINO.NHEADS
|
371 |
+
ret["transformer_dim_feedforward"] = cfg.MODEL.SEM_SEG_HEAD.DIM_FEEDFORWARD # deformable transformer encoder
|
372 |
+
ret[
|
373 |
+
"transformer_enc_layers"
|
374 |
+
] = cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS # a separate config
|
375 |
+
ret["transformer_in_features"] = cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES # ['res3', 'res4', 'res5']
|
376 |
+
ret["common_stride"] = cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE
|
377 |
+
ret["total_num_feature_levels"] = cfg.MODEL.SEM_SEG_HEAD.TOTAL_NUM_FEATURE_LEVELS
|
378 |
+
ret["num_feature_levels"] = cfg.MODEL.SEM_SEG_HEAD.NUM_FEATURE_LEVELS
|
379 |
+
ret["feature_order"] = cfg.MODEL.SEM_SEG_HEAD.FEATURE_ORDER
|
380 |
+
ret["ViTBackbone"] = cfg.MODEL.BACKBONE.NAME in ['D2_EVA02', 'D2_EVA01' , 'D2_ViT']
|
381 |
+
return ret
|
382 |
+
|
383 |
+
@autocast(enabled=False)
|
384 |
+
def forward_features(self, features, masks, early_fusion=None):
|
385 |
+
"""
|
386 |
+
:param features: multi-scale features from the backbone
|
387 |
+
:param masks: image mask
|
388 |
+
:return: enhanced multi-scale features and mask feature (1/4 resolution) for the decoder to produce binary mask
|
389 |
+
"""
|
390 |
+
# backbone features
|
391 |
+
srcs = []
|
392 |
+
pos = []
|
393 |
+
# additional downsampled features
|
394 |
+
srcsl = []
|
395 |
+
posl = []
|
396 |
+
|
397 |
+
if self.isViTBackbone:
|
398 |
+
for idx, f in enumerate(self.transformer_in_features[::-1]):
|
399 |
+
x = features[f].float() # deformable detr does not support half precision
|
400 |
+
srcs.append(x)
|
401 |
+
pos.append(self.pe_layer(x))
|
402 |
+
if self.feature_order != 'low2high':
|
403 |
+
srcs = srcs[::-1]
|
404 |
+
pos = pos[::-1]
|
405 |
+
else:
|
406 |
+
if self.total_num_feature_levels > self.transformer_num_feature_levels:
|
407 |
+
smallest_feat = features[self.transformer_in_features[self.low_resolution_index]].float()
|
408 |
+
_len_srcs = self.transformer_num_feature_levels
|
409 |
+
for l in range(_len_srcs, self.total_num_feature_levels):
|
410 |
+
if l == _len_srcs:
|
411 |
+
src = self.input_proj[l](smallest_feat)
|
412 |
+
else:
|
413 |
+
src = self.input_proj[l](srcsl[-1])
|
414 |
+
srcsl.append(src)
|
415 |
+
posl.append(self.pe_layer(src))
|
416 |
+
srcsl = srcsl[::-1]
|
417 |
+
# Reverse feature maps
|
418 |
+
|
419 |
+
|
420 |
+
for idx, f in enumerate(self.transformer_in_features[::-1]):
|
421 |
+
x = features[f].float() # deformable detr does not support half precision
|
422 |
+
srcs.append(self.input_proj[idx](x))
|
423 |
+
pos.append(self.pe_layer(x))
|
424 |
+
srcs.extend(srcsl) if self.feature_order == 'low2high' else srcsl.extend(srcs)
|
425 |
+
pos.extend(posl) if self.feature_order == 'low2high' else posl.extend(pos)
|
426 |
+
if self.feature_order != 'low2high':
|
427 |
+
srcs = srcsl
|
428 |
+
pos = posl
|
429 |
+
|
430 |
+
y, spatial_shapes, level_start_index, zero_loss = self.transformer(srcs, masks, pos, early_fusion)
|
431 |
+
bs = y.shape[0]
|
432 |
+
|
433 |
+
split_size_or_sections = [None] * self.total_num_feature_levels
|
434 |
+
for i in range(self.total_num_feature_levels):
|
435 |
+
if i < self.total_num_feature_levels - 1:
|
436 |
+
split_size_or_sections[i] = level_start_index[i + 1] - level_start_index[i]
|
437 |
+
else:
|
438 |
+
split_size_or_sections[i] = y.shape[1] - level_start_index[i]
|
439 |
+
y = torch.split(y, split_size_or_sections, dim=1)
|
440 |
+
|
441 |
+
out = []
|
442 |
+
multi_scale_features = []
|
443 |
+
num_cur_levels = 0
|
444 |
+
for i, z in enumerate(y):
|
445 |
+
out.append(z.transpose(1, 2).view(bs, -1, spatial_shapes[i][0], spatial_shapes[i][1]))
|
446 |
+
|
447 |
+
# append `out` with extra FPN levels
|
448 |
+
# Reverse feature maps into top-down order (from low to high resolution)
|
449 |
+
for idx, f in enumerate(self.in_features[:self.num_fpn_levels][::-1]):
|
450 |
+
x = features[f].float()
|
451 |
+
lateral_conv = self.lateral_convs[idx]
|
452 |
+
output_conv = self.output_convs[idx]
|
453 |
+
cur_fpn = lateral_conv(x)
|
454 |
+
# Following FPN implementation, we use nearest upsampling here
|
455 |
+
y = cur_fpn + F.interpolate(out[self.high_resolution_index], size=cur_fpn.shape[-2:], mode="bilinear", align_corners=False)
|
456 |
+
y = output_conv(y)
|
457 |
+
out.append(y)
|
458 |
+
for o in out:
|
459 |
+
if num_cur_levels < self.total_num_feature_levels:
|
460 |
+
multi_scale_features.append(o)
|
461 |
+
num_cur_levels += 1
|
462 |
+
return self.mask_features(out[-1]), out[0], multi_scale_features, zero_loss
|
463 |
+
|
GLEE/glee/models/pixel_decoder/ops/functions/__init__.py
ADDED
@@ -0,0 +1,13 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+
+from .ms_deform_attn_func import MSDeformAttnFunction
+
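The __init__.py above only re-exports the autograd wrapper, so other modules in the commit can pull it in with a single import. A hypothetical usage line, assuming the repository root is on PYTHONPATH (the exact package prefix depends on how the Space mounts the code):

# Illustrative import of the re-exported wrapper; the package prefix is an assumption.
from GLEE.glee.models.pixel_decoder.ops.functions import MSDeformAttnFunction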
GLEE/glee/models/pixel_decoder/ops/functions/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (262 Bytes). View file
GLEE/glee/models/pixel_decoder/ops/functions/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (232 Bytes). View file
GLEE/glee/models/pixel_decoder/ops/functions/__pycache__/ms_deform_attn_func.cpython-38.pyc
ADDED
Binary file (2.7 kB). View file
GLEE/glee/models/pixel_decoder/ops/functions/__pycache__/ms_deform_attn_func.cpython-39.pyc
ADDED
Binary file (2.64 kB). View file
GLEE/glee/models/pixel_decoder/ops/functions/ms_deform_attn_func.py
ADDED
@@ -0,0 +1,72 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import torch
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+try:
+    import MultiScaleDeformableAttention as MSDA
+except ModuleNotFoundError as e:
+    info_string = (
+        "\n\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\n"
+        "\t`cd maskdino/modeling/pixel_decoder/ops`\n"
+        "\t`sh make.sh`\n"
+    )
+    # raise ModuleNotFoundError(info_string)
+
+
+class MSDeformAttnFunction(Function):
+    @staticmethod
+    def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
+        ctx.im2col_step = im2col_step
+        output = MSDA.ms_deform_attn_forward(
+            value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
+        ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
+        grad_value, grad_sampling_loc, grad_attn_weight = \
+            MSDA.ms_deform_attn_backward(
+                value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
+
+        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
+
+
+def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
+    # for debug and test only,
+    # need to use cuda version instead
+    N_, S_, M_, D_ = value.shape
+    _, Lq_, M_, L_, P_, _ = sampling_locations.shape
+    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
+    sampling_grids = 2 * sampling_locations - 1
+    sampling_value_list = []
+    for lid_, (H_, W_) in enumerate(value_spatial_shapes):
+        # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
+        value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
+        # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
+        sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
+        # N_*M_, D_, Lq_, P_
+        sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
+                                          mode='bilinear', padding_mode='zeros', align_corners=False)
+        sampling_value_list.append(sampling_value_l_)
+    # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
+    attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
+    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
+    return output.transpose(1, 2).contiguous()
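For reference (not part of the committed file), the tensor contract of the pure-PyTorch fallback above can be sanity-checked with random inputs; every size below is made up for illustration, and the import path prefix is an assumption about how the package is mounted.

# Made-up shapes exercising ms_deform_attn_core_pytorch (the CPU/debug fallback above).
import torch
from GLEE.glee.models.pixel_decoder.ops.functions.ms_deform_attn_func import ms_deform_attn_core_pytorch  # path prefix is an assumption

N, M, D = 2, 8, 32                        # batch, attention heads, channels per head
spatial_shapes = [(16, 16), (8, 8)]       # (H, W) of each value level
L, P, Lq = len(spatial_shapes), 4, 100    # levels, sampling points per level, queries
S = sum(h * w for h, w in spatial_shapes)

value = torch.rand(N, S, M, D)
sampling_locations = torch.rand(N, Lq, M, L, P, 2)   # normalized to [0, 1]
attention_weights = torch.rand(N, Lq, M, L, P)
attention_weights = attention_weights / attention_weights.sum(dim=(-2, -1), keepdim=True)

out = ms_deform_attn_core_pytorch(value, spatial_shapes, sampling_locations, attention_weights)
print(out.shape)  # torch.Size([2, 100, 256]), i.e. (N, Lq, M*D)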