fffiloni commited on
Commit
353118d
·
1 Parent(s): a2f8bf2

Upload 5 files

Browse files
xdecoder/body/decoder/build.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .registry import model_entrypoints
2
+ from .registry import is_model
3
+
4
+ from .xdecoder import *
5
+
6
+ def build_decoder(config, *args, **kwargs):
7
+ model_name = config['MODEL']['DECODER']['NAME']
8
+
9
+ if not is_model(model_name):
10
+ raise ValueError(f'Unkown model: {model_name}')
11
+
12
+ return model_entrypoints(model_name)(config, *args, **kwargs)
xdecoder/body/decoder/registry.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _model_entrypoints = {}
2
+
3
+ def register_decoder(fn):
4
+ module_name_split = fn.__module__.split('.')
5
+ model_name = module_name_split[-1]
6
+ _model_entrypoints[model_name] = fn
7
+ return fn
8
+
9
+ def model_entrypoints(model_name):
10
+ return _model_entrypoints[model_name]
11
+
12
+ def is_model(model_name):
13
+ return model_name in _model_entrypoints
xdecoder/body/decoder/tmp.py ADDED
@@ -0,0 +1,664 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
3
+ import logging
4
+ from typing import Optional
5
+
6
+ import torch
7
+ from torch import nn, Tensor
8
+ from torch.nn import functional as F
9
+
10
+ from timm.models.layers import trunc_normal_
11
+ from detectron2.layers import Conv2d
12
+ import fvcore.nn.weight_init as weight_init
13
+
14
+ from .registry import register_decoder
15
+ from ...utils import configurable
16
+ from ...modules import PositionEmbeddingSine
17
+
18
+ from image2html.visualizer import VL
19
+
20
+
21
+ class SelfAttentionLayer(nn.Module):
22
+
23
+ def __init__(self, d_model, nhead, dropout=0.0,
24
+ activation="relu", normalize_before=False):
25
+ super().__init__()
26
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
27
+
28
+ self.norm = nn.LayerNorm(d_model)
29
+ self.dropout = nn.Dropout(dropout)
30
+
31
+ self.activation = _get_activation_fn(activation)
32
+ self.normalize_before = normalize_before
33
+
34
+ self._reset_parameters()
35
+
36
+ def _reset_parameters(self):
37
+ for p in self.parameters():
38
+ if p.dim() > 1:
39
+ nn.init.xavier_uniform_(p)
40
+
41
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
42
+ return tensor if pos is None else tensor + pos
43
+
44
+ def forward_post(self, tgt,
45
+ tgt_mask: Optional[Tensor] = None,
46
+ tgt_key_padding_mask: Optional[Tensor] = None,
47
+ query_pos: Optional[Tensor] = None):
48
+ q = k = self.with_pos_embed(tgt, query_pos)
49
+ tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
50
+ key_padding_mask=tgt_key_padding_mask)[0]
51
+ tgt = tgt + self.dropout(tgt2)
52
+ tgt = self.norm(tgt)
53
+
54
+ return tgt
55
+
56
+ def forward_pre(self, tgt,
57
+ tgt_mask: Optional[Tensor] = None,
58
+ tgt_key_padding_mask: Optional[Tensor] = None,
59
+ query_pos: Optional[Tensor] = None):
60
+ tgt2 = self.norm(tgt)
61
+ q = k = self.with_pos_embed(tgt2, query_pos)
62
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
63
+ key_padding_mask=tgt_key_padding_mask)[0]
64
+ tgt = tgt + self.dropout(tgt2)
65
+
66
+ return tgt
67
+
68
+ def forward(self, tgt,
69
+ tgt_mask: Optional[Tensor] = None,
70
+ tgt_key_padding_mask: Optional[Tensor] = None,
71
+ query_pos: Optional[Tensor] = None):
72
+ if self.normalize_before:
73
+ return self.forward_pre(tgt, tgt_mask,
74
+ tgt_key_padding_mask, query_pos)
75
+ return self.forward_post(tgt, tgt_mask,
76
+ tgt_key_padding_mask, query_pos)
77
+
78
+
79
+ class CrossAttentionLayer(nn.Module):
80
+
81
+ def __init__(self, d_model, nhead, dropout=0.0,
82
+ activation="relu", normalize_before=False):
83
+ super().__init__()
84
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
85
+
86
+ self.norm = nn.LayerNorm(d_model)
87
+ self.dropout = nn.Dropout(dropout)
88
+
89
+ self.activation = _get_activation_fn(activation)
90
+ self.normalize_before = normalize_before
91
+
92
+ self._reset_parameters()
93
+
94
+ def _reset_parameters(self):
95
+ for p in self.parameters():
96
+ if p.dim() > 1:
97
+ nn.init.xavier_uniform_(p)
98
+
99
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
100
+ return tensor if pos is None else tensor + pos
101
+
102
+ def forward_post(self, tgt, memory,
103
+ memory_mask: Optional[Tensor] = None,
104
+ memory_key_padding_mask: Optional[Tensor] = None,
105
+ pos: Optional[Tensor] = None,
106
+ query_pos: Optional[Tensor] = None):
107
+ tgt2, avg_attn = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
108
+ key=self.with_pos_embed(memory, pos),
109
+ value=memory, attn_mask=memory_mask,
110
+ key_padding_mask=memory_key_padding_mask)
111
+ tgt = tgt + self.dropout(tgt2)
112
+ tgt = self.norm(tgt)
113
+ return tgt, avg_attn
114
+
115
+ def forward_pre(self, tgt, memory,
116
+ memory_mask: Optional[Tensor] = None,
117
+ memory_key_padding_mask: Optional[Tensor] = None,
118
+ pos: Optional[Tensor] = None,
119
+ query_pos: Optional[Tensor] = None):
120
+ tgt2 = self.norm(tgt)
121
+ tgt2, avg_attn = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
122
+ key=self.with_pos_embed(memory, pos),
123
+ value=memory, attn_mask=memory_mask,
124
+ key_padding_mask=memory_key_padding_mask)
125
+ tgt = tgt + self.dropout(tgt2)
126
+
127
+ return tgt, avg_attn
128
+
129
+ def forward(self, tgt, memory,
130
+ memory_mask: Optional[Tensor] = None,
131
+ memory_key_padding_mask: Optional[Tensor] = None,
132
+ pos: Optional[Tensor] = None,
133
+ query_pos: Optional[Tensor] = None):
134
+ if self.normalize_before:
135
+ return self.forward_pre(tgt, memory, memory_mask,
136
+ memory_key_padding_mask, pos, query_pos)
137
+ return self.forward_post(tgt, memory, memory_mask,
138
+ memory_key_padding_mask, pos, query_pos)
139
+
140
+
141
+ class FFNLayer(nn.Module):
142
+
143
+ def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
144
+ activation="relu", normalize_before=False):
145
+ super().__init__()
146
+ # Implementation of Feedforward model
147
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
148
+ self.dropout = nn.Dropout(dropout)
149
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
150
+
151
+ self.norm = nn.LayerNorm(d_model)
152
+
153
+ self.activation = _get_activation_fn(activation)
154
+ self.normalize_before = normalize_before
155
+
156
+ self._reset_parameters()
157
+
158
+ def _reset_parameters(self):
159
+ for p in self.parameters():
160
+ if p.dim() > 1:
161
+ nn.init.xavier_uniform_(p)
162
+
163
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
164
+ return tensor if pos is None else tensor + pos
165
+
166
+ def forward_post(self, tgt):
167
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
168
+ tgt = tgt + self.dropout(tgt2)
169
+ tgt = self.norm(tgt)
170
+ return tgt
171
+
172
+ def forward_pre(self, tgt):
173
+ tgt2 = self.norm(tgt)
174
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
175
+ tgt = tgt + self.dropout(tgt2)
176
+ return tgt
177
+
178
+ def forward(self, tgt):
179
+ if self.normalize_before:
180
+ return self.forward_pre(tgt)
181
+ return self.forward_post(tgt)
182
+
183
+
184
+ def _get_activation_fn(activation):
185
+ """Return an activation function given a string"""
186
+ if activation == "relu":
187
+ return F.relu
188
+ if activation == "gelu":
189
+ return F.gelu
190
+ if activation == "glu":
191
+ return F.glu
192
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
193
+
194
+
195
+ class MLP(nn.Module):
196
+ """ Very simple multi-layer perceptron (also called FFN)"""
197
+
198
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
199
+ super().__init__()
200
+ self.num_layers = num_layers
201
+ h = [hidden_dim] * (num_layers - 1)
202
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
203
+
204
+ def forward(self, x):
205
+ for i, layer in enumerate(self.layers):
206
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
207
+ return x
208
+
209
+
210
+ class MultiScaleMaskedTransformerDecoder(nn.Module):
211
+
212
+ _version = 2
213
+
214
+ @configurable
215
+ def __init__(
216
+ self,
217
+ lang_encoder: nn.Module,
218
+ in_channels,
219
+ mask_classification=True,
220
+ *,
221
+ hidden_dim: int,
222
+ dim_proj: int,
223
+ num_queries: int,
224
+ contxt_len: int,
225
+ nheads: int,
226
+ dim_feedforward: int,
227
+ dec_layers: int,
228
+ pre_norm: bool,
229
+ mask_dim: int,
230
+ task_switch: dict,
231
+ captioning_step: int,
232
+ enforce_input_project: bool,
233
+ ):
234
+ """
235
+ NOTE: this interface is experimental.
236
+ Args:
237
+ in_channels: channels of the input features
238
+ mask_classification: whether to add mask classifier or not
239
+ num_classes: number of classes
240
+ hidden_dim: Transformer feature dimension
241
+ num_queries: number of queries
242
+ nheads: number of heads
243
+ dim_feedforward: feature dimension in feedforward network
244
+ enc_layers: number of Transformer encoder layers
245
+ dec_layers: number of Transformer decoder layers
246
+ pre_norm: whether to use pre-LayerNorm or not
247
+ mask_dim: mask feature dimension
248
+ enforce_input_project: add input project 1x1 conv even if input
249
+ channels and hidden dim is identical
250
+ """
251
+ super().__init__()
252
+ assert mask_classification, "Only support mask classification model"
253
+ self.mask_classification = mask_classification
254
+
255
+ # positional encoding
256
+ N_steps = hidden_dim // 2
257
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
258
+
259
+ # define Transformer decoder here
260
+ self.num_heads = nheads
261
+ self.num_layers = dec_layers
262
+ self.contxt_len = contxt_len
263
+ self.transformer_self_attention_layers = nn.ModuleList()
264
+ self.transformer_cross_attention_layers = nn.ModuleList()
265
+ self.transformer_ffn_layers = nn.ModuleList()
266
+
267
+ for _ in range(self.num_layers):
268
+ self.transformer_self_attention_layers.append(
269
+ SelfAttentionLayer(
270
+ d_model=hidden_dim,
271
+ nhead=nheads,
272
+ dropout=0.0,
273
+ normalize_before=pre_norm,
274
+ )
275
+ )
276
+
277
+ self.transformer_cross_attention_layers.append(
278
+ CrossAttentionLayer(
279
+ d_model=hidden_dim,
280
+ nhead=nheads,
281
+ dropout=0.0,
282
+ normalize_before=pre_norm,
283
+ )
284
+ )
285
+
286
+ self.transformer_ffn_layers.append(
287
+ FFNLayer(
288
+ d_model=hidden_dim,
289
+ dim_feedforward=dim_feedforward,
290
+ dropout=0.0,
291
+ normalize_before=pre_norm,
292
+ )
293
+ )
294
+
295
+ self.decoder_norm = nn.LayerNorm(hidden_dim)
296
+
297
+ self.num_queries = num_queries
298
+ # learnable query features
299
+ self.query_feat = nn.Embedding(num_queries, hidden_dim)
300
+ # learnable query p.e.
301
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
302
+
303
+ # level embedding (we always use 3 scales)
304
+ self.num_feature_levels = 3
305
+ self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
306
+ self.input_proj = nn.ModuleList()
307
+
308
+ for _ in range(self.num_feature_levels):
309
+ if in_channels != hidden_dim or enforce_input_project:
310
+ self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
311
+ weight_init.c2_xavier_fill(self.input_proj[-1])
312
+ else:
313
+ self.input_proj.append(nn.Sequential())
314
+
315
+ self.task_switch = task_switch
316
+
317
+ # output FFNs
318
+ self.lang_encoder = lang_encoder
319
+ if self.task_switch['mask']:
320
+ self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
321
+
322
+ self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
323
+ trunc_normal_(self.class_embed, std=.02)
324
+
325
+ if task_switch['bbox']:
326
+ self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
327
+
328
+ # Caption Project and query
329
+ if task_switch['captioning']:
330
+ self.caping_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
331
+ trunc_normal_(self.caping_embed, std=.02)
332
+ self.query_feat_caping = nn.Embedding(contxt_len, hidden_dim)
333
+ self.captioning_step = captioning_step
334
+
335
+ # register self_attn_mask to avoid information leakage, it includes interaction between object query, class query and caping query
336
+ self_attn_mask = torch.zeros((1, num_queries + contxt_len, num_queries + contxt_len)).bool()
337
+ self_attn_mask[:, :num_queries, num_queries:] = True # object+class query does not attend with caption query.
338
+ self_attn_mask[:, num_queries:, num_queries:] = torch.triu(torch.ones((1, contxt_len, contxt_len)), diagonal=1).bool() # caption query only attend with previous token.
339
+ self_attn_mask[:, :num_queries-1, num_queries-1:num_queries] = True # object query does not attend with class query.
340
+ self_attn_mask[:, num_queries-1:num_queries, :num_queries-1] = True # class query does not attend with object query.
341
+ self.register_buffer("self_attn_mask", self_attn_mask)
342
+
343
+
344
+ @classmethod
345
+ def from_config(cls, cfg, in_channels, lang_encoder, mask_classification, extra):
346
+ ret = {}
347
+
348
+ ret["lang_encoder"] = lang_encoder
349
+ ret["in_channels"] = in_channels
350
+ ret["mask_classification"] = mask_classification
351
+
352
+ enc_cfg = cfg['MODEL']['ENCODER']
353
+ dec_cfg = cfg['MODEL']['DECODER']
354
+
355
+ ret["hidden_dim"] = dec_cfg['HIDDEN_DIM']
356
+ ret["dim_proj"] = cfg['MODEL']['DIM_PROJ']
357
+ ret["num_queries"] = dec_cfg['NUM_OBJECT_QUERIES']
358
+ ret["contxt_len"] = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']
359
+
360
+ # Transformer parameters:
361
+ ret["nheads"] = dec_cfg['NHEADS']
362
+ ret["dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD']
363
+
364
+ # NOTE: because we add learnable query features which requires supervision,
365
+ # we add minus 1 to decoder layers to be consistent with our loss
366
+ # implementation: that is, number of auxiliary losses is always
367
+ # equal to number of decoder layers. With learnable query features, the number of
368
+ # auxiliary losses equals number of decoders plus 1.
369
+ assert dec_cfg['DEC_LAYERS'] >= 1
370
+ ret["dec_layers"] = dec_cfg['DEC_LAYERS'] - 1
371
+ ret["pre_norm"] = dec_cfg['PRE_NORM']
372
+ ret["enforce_input_project"] = dec_cfg['ENFORCE_INPUT_PROJ']
373
+ ret["mask_dim"] = enc_cfg['MASK_DIM']
374
+
375
+ ret["task_switch"] = extra['task_switch']
376
+ ret["captioning_step"] = dec_cfg['CAPTIONING'].get('STEP', 50)
377
+
378
+ return ret
379
+
380
+ def forward(self, x, mask_features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
381
+ if task == 'captioning_infer':
382
+ return self.forward_captioning(x, mask_features, mask=mask, target_queries=target_queries, target_vlp=target_vlp, task=task, extra=extra)
383
+ # x is a list of multi-scale feature
384
+ assert len(x) == self.num_feature_levels
385
+ src = []
386
+ pos = []
387
+ size_list = []
388
+
389
+ # disable mask, it does not affect performance
390
+ del mask
391
+ for i in range(self.num_feature_levels):
392
+ size_list.append(x[i].shape[-2:])
393
+ pos.append(self.pe_layer(x[i], None).flatten(2))
394
+ src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])
395
+
396
+ # flatten NxCxHxW to HWxNxC
397
+ pos[-1] = pos[-1].permute(2, 0, 1)
398
+ src[-1] = src[-1].permute(2, 0, 1)
399
+
400
+ _, bs, _ = src[0].shape
401
+
402
+ # QxNxC
403
+ query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
404
+ output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
405
+
406
+ predictions_class = []
407
+ predictions_mask = []
408
+ predictions_bbox = []
409
+ predictions_caption = []
410
+ predictions_captioning = []
411
+
412
+ self_tgt_mask = None
413
+ if self.training and task == 'vlp' and self.task_switch['captioning']:
414
+ output = torch.cat((output, self.query_feat_caping.weight.unsqueeze(1).repeat(1, bs, 1)), dim=0) # concat object query, class token and caption token.
415
+ caping_lang_embed = torch.cat([caption['caption_tokens'] for caption in target_vlp], dim=0).transpose(0, 1) # language output
416
+ query_embed = torch.cat((query_embed, caping_lang_embed), dim=0) # may not add at the beginning.
417
+ self_tgt_mask = self.self_attn_mask.repeat(output.shape[1]*self.num_heads, 1, 1)
418
+ elif (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']) \
419
+ or ((self.training and task == 'openimage') and self.task_switch['openimage']['grounding']):
420
+ self_tgt_mask = self.self_attn_mask[:,:self.num_queries,:self.num_queries].repeat(output.shape[1]*self.num_heads, 1, 1)
421
+ grounding_tokens = extra['grounding_tokens']
422
+ _grounding_tokens = grounding_tokens.detach().clone()
423
+ # initialize with negative attention at the beginning.
424
+ pad_tgt_mask = torch.ones((1, self.num_queries + (self.num_queries-1) + len(grounding_tokens), self.num_queries + (self.num_queries-1) + len(grounding_tokens)), device=self_tgt_mask.device).bool().repeat(output.shape[1]*self.num_heads, 1, 1)
425
+ pad_tgt_mask[:,:self.num_queries,:self.num_queries] = self_tgt_mask
426
+ pad_tgt_mask[:,self.num_queries:,self.num_queries:] = False # grounding tokens could attend with eatch other
427
+ self_tgt_mask = pad_tgt_mask
428
+ output = torch.cat((output, output[:-1]), dim=0)
429
+ query_embed = torch.cat((query_embed, query_embed[:-1]), dim=0) # also pad language embdding to fix embedding
430
+ else:
431
+ self_tgt_mask = self.self_attn_mask[:,:self.num_queries,:self.num_queries].repeat(output.shape[1]*self.num_heads, 1, 1)
432
+
433
+ # prediction heads on learnable query features
434
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0], task=task)
435
+ attn_mask = results["attn_mask"]
436
+ predictions_class.append(results["outputs_class"])
437
+ predictions_mask.append(results["outputs_mask"])
438
+ predictions_bbox.append(results["outputs_bbox"])
439
+ predictions_caption.append(results["outputs_caption"])
440
+ predictions_captioning.append(results["outputs_captionting"])
441
+
442
+ for i in range(self.num_layers):
443
+ level_index = i % self.num_feature_levels
444
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
445
+
446
+ if self.training and task == 'vlp' and self.task_switch['captioning']:
447
+ attn_mask = torch.cat((attn_mask, torch.zeros_like(attn_mask[:, :self.contxt_len, :])), dim=1)
448
+ # attention: cross-attention first
449
+ output, avg_attn = self.transformer_cross_attention_layers[i](
450
+ output, src[level_index],
451
+ memory_mask=attn_mask,
452
+ memory_key_padding_mask=None, # here we do not apply masking on padded region
453
+ pos=pos[level_index], query_pos=query_embed
454
+ )
455
+
456
+ if (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']) \
457
+ or ((self.training and task == 'openimage') and self.task_switch['openimage']['grounding']):
458
+ output = torch.cat((output, _grounding_tokens), dim=0)
459
+ query_embed = torch.cat((query_embed, grounding_tokens), dim=0)
460
+
461
+ output = self.transformer_self_attention_layers[i](
462
+ output, tgt_mask=self_tgt_mask,
463
+ tgt_key_padding_mask=None,
464
+ query_pos=query_embed
465
+ )
466
+
467
+ # FFN
468
+ output = self.transformer_ffn_layers[i](
469
+ output
470
+ )
471
+
472
+ if ((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding'] \
473
+ or ((self.training and task == 'openimage') and self.task_switch['openimage']['grounding']):
474
+ _grounding_tokens = output[-len(_grounding_tokens):]
475
+ output = output[:-len(_grounding_tokens)]
476
+ query_embed = query_embed[:-len(_grounding_tokens)]
477
+
478
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i, task=task)
479
+ attn_mask = results["attn_mask"]
480
+ predictions_class.append(results["outputs_class"])
481
+ predictions_mask.append(results["outputs_mask"])
482
+ predictions_bbox.append(results["outputs_bbox"])
483
+ predictions_caption.append(results["outputs_caption"])
484
+ predictions_captioning.append(results["outputs_captionting"])
485
+
486
+ assert len(predictions_class) == self.num_layers + 1
487
+ if task == 'vlp':
488
+ out = {'pred_captionings': predictions_captioning[-1],
489
+ 'pred_captions': predictions_caption[-1],
490
+ 'aux_outputs': [{'pred_captionings': x, 'pred_captions': y } for x, y in zip(predictions_captioning[:-1], predictions_caption[:-1])]}
491
+ return out
492
+ else:
493
+ out = {
494
+ 'pred_logits': predictions_class[-1],
495
+ 'pred_masks': predictions_mask[-1],
496
+ 'pred_boxes': predictions_bbox[-1],
497
+ 'pred_captions': predictions_caption[-1],
498
+ 'aux_outputs': self._set_aux_loss(
499
+ predictions_class if self.mask_classification else None, predictions_mask, predictions_bbox, predictions_caption
500
+ )
501
+ }
502
+ return out
503
+
504
+ def forward_captioning(self, x, mask_features, mask = None, target_queries = None, target_vlp = None, task='seg', extra={}):
505
+ # x is a list of multi-scale feature
506
+ assert len(x) == self.num_feature_levels
507
+ src = []
508
+ pos = []
509
+ size_list = []
510
+
511
+ # disable mask, it does not affect performance
512
+ del mask
513
+ for i in range(self.num_feature_levels):
514
+ size_list.append(x[i].shape[-2:])
515
+ pos.append(self.pe_layer(x[i], None).flatten(2))
516
+ src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])
517
+
518
+ # flatten NxCxHxW to HWxNxC
519
+ pos[-1] = pos[-1].permute(2, 0, 1)
520
+ src[-1] = src[-1].permute(2, 0, 1)
521
+
522
+ _, bs, _ = src[0].shape
523
+
524
+ # QxNxC
525
+ query_embed_ = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
526
+ query_feat = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
527
+ caping_lang_token = extra['start_token'].repeat(bs, 1)
528
+ query_feat_caping = self.query_feat_caping.weight.unsqueeze(1).repeat(1, bs, 1)
529
+
530
+ # prepare token embedding for evaluation
531
+ token_embs = self.lang_encoder.lang_encoder.token_embedding.weight
532
+ # token_embs = (token_embs / token_embs.norm(dim=-1, keepdim=True) + 1e-7)
533
+
534
+ for cap_idx in range(0, self.captioning_step):
535
+ caping_lang_embed = self.lang_encoder.forward_language_token((caping_lang_token,))[0].transpose(0, 1)
536
+ query_embed = torch.cat((query_embed_, caping_lang_embed), dim=0) # may not add at the beginning.
537
+ output = torch.cat((query_feat, query_feat_caping), dim=0) # concat object query, class token and caption token.
538
+
539
+ # prediction heads on learnable query features
540
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0], task=task)
541
+ attn_mask = results["attn_mask"]
542
+
543
+ for i in range(self.num_layers):
544
+ level_index = i % self.num_feature_levels
545
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
546
+ attn_mask = torch.cat((attn_mask, torch.zeros_like(attn_mask[:, :self.contxt_len, :])), dim=1)
547
+ self_tgt_mask = self.self_attn_mask.repeat(output.shape[1]*self.num_heads, 1, 1)
548
+
549
+ # attention: cross-attention first
550
+ output, avg_attn = self.transformer_cross_attention_layers[i](
551
+ output, src[level_index],
552
+ memory_mask=attn_mask,
553
+ memory_key_padding_mask=None, # here we do not apply masking on padded region
554
+ pos=pos[level_index], query_pos=query_embed
555
+ )
556
+
557
+ output = self.transformer_self_attention_layers[i](
558
+ output, tgt_mask=self_tgt_mask,
559
+ tgt_key_padding_mask=None,
560
+ query_pos=query_embed
561
+ )
562
+
563
+ # FFN
564
+ output = self.transformer_ffn_layers[i](
565
+ output
566
+ )
567
+
568
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i, task=task)
569
+ attn_mask = results["attn_mask"]
570
+
571
+ pred_captions_gen = results['outputs_captionting']
572
+ # pred_captions_gen = (pred_captions_gen / pred_captions_gen.norm(dim=-1, keepdim=True) + 1e-7)
573
+ pred_captions_gen = pred_captions_gen @ token_embs.t()
574
+ caping_lang_token[:,cap_idx+1] = pred_captions_gen[:,cap_idx].max(-1)[1]
575
+
576
+ out = {'pred_captionings': caping_lang_token,
577
+ 'pred_texts': self.lang_encoder.tokenizer.batch_decode(caping_lang_token, skip_special_tokens=True)}
578
+ return out
579
+
580
+
581
+ def forward_prediction_heads(self, output, mask_features, attn_mask_target_size, layer_id=-1, task='seg'):
582
+ decoder_output = self.decoder_norm(output)
583
+ decoder_output = decoder_output.transpose(0, 1)
584
+
585
+ # extract image captioning token from decoder output.
586
+ if self.task_switch['captioning'] and (task == 'vlp' or task == 'captioning_infer'):
587
+ outputs_captionting = decoder_output[:,self.num_queries:] @ self.caping_embed
588
+ else:
589
+ outputs_captionting = None
590
+
591
+ # recompute class token output.
592
+ norm_decoder_output = decoder_output / (decoder_output.norm(dim=-1, keepdim=True) + 1e-7)
593
+ obj_token = norm_decoder_output[:,:self.num_queries-1]
594
+ cls_token = norm_decoder_output[:,self.num_queries-1:self.num_queries]
595
+
596
+ sim = (cls_token @ obj_token.transpose(1,2)).softmax(-1)[:,0,:,None] # TODO include class token.
597
+ cls_token = (sim * decoder_output[:,:self.num_queries-1]).sum(dim=1, keepdim=True)
598
+
599
+ if (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']) \
600
+ or ((self.training and task == 'openimage') and self.task_switch['openimage']['grounding']):
601
+ decoder_output = torch.cat((decoder_output[:,:self.num_queries-1], cls_token, decoder_output[:,self.num_queries:2*self.num_queries-1]), dim=1)
602
+ else:
603
+ decoder_output = torch.cat((decoder_output[:,:self.num_queries-1], cls_token), dim=1)
604
+
605
+ # compute class, mask and bbox.
606
+ class_embed = decoder_output @ self.class_embed
607
+ # HACK do not compute similarity if mask is not on
608
+ outputs_class = self.lang_encoder.compute_similarity(class_embed, fake=(((not self.task_switch['mask']) and self.training) or (task == 'openimage')))
609
+
610
+ if self.task_switch['mask'] or self.task_switch['openimage']['mask']:
611
+ mask_embed = self.mask_embed(decoder_output)
612
+ outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
613
+
614
+ # NOTE: prediction is of higher-resolution
615
+ # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
616
+ attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)
617
+
618
+ # must use bool type
619
+ # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
620
+ attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
621
+ attn_mask = attn_mask.detach()
622
+
623
+ # NOTE: fill False for cls token (JY)
624
+ attn_mask[:, self.num_queries:self.num_queries+1].fill_(False)
625
+ else:
626
+ outputs_mask = None
627
+ attn_mask = torch.zeros((list(decoder_output.shape[:2]) + [attn_mask_target_size[0]*attn_mask_target_size[1]]), device=decoder_output.device).repeat(self.num_heads, 1, 1).bool()
628
+
629
+ outputs_bbox = [None for i in range(len(decoder_output))]
630
+ if self.task_switch['bbox']:
631
+ outputs_bbox = self.bbox_embed(decoder_output)
632
+
633
+ outputs_caption = None
634
+ if self.task_switch['caption']:
635
+ outputs_caption = class_embed
636
+
637
+
638
+ results = {
639
+ "outputs_class": outputs_class,
640
+ "outputs_mask": outputs_mask,
641
+ "outputs_bbox": outputs_bbox,
642
+ "attn_mask": attn_mask,
643
+ "outputs_caption": outputs_caption,
644
+ "outputs_captionting": outputs_captionting,
645
+ }
646
+ return results
647
+
648
+ @torch.jit.unused
649
+ def _set_aux_loss(self, outputs_class, outputs_seg_masks, outputs_boxes, outputs_captions):
650
+ # this is a workaround to make torchscript happy, as torchscript
651
+ # doesn't support dictionary with non-homogeneous values, such
652
+ # as a dict having both a Tensor and a list.
653
+ if self.mask_classification:
654
+ return [
655
+ {"pred_logits": a, "pred_masks": b, "pred_boxes": c, "pred_captions": d}
656
+ for a, b, c, d in zip(outputs_class[:-1], outputs_seg_masks[:-1], outputs_boxes[:-1], outputs_captions[:-1])
657
+ ]
658
+ else:
659
+ return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]
660
+
661
+
662
+ @register_decoder
663
+ def get_masked_transformer_decoder(cfg, in_channels, lang_encoder, mask_classification, extra):
664
+ return MultiScaleMaskedTransformerDecoder(cfg, in_channels, lang_encoder, mask_classification, extra)
xdecoder/body/decoder/xdecoder.py ADDED
@@ -0,0 +1,700 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
3
+
4
+ # --------------------------------------------------------
5
+ # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
6
+ # Copyright (c) 2022 Microsoft
7
+ # Licensed under The MIT License [see LICENSE for details]
8
+ # Written by Xueyan Zou ([email protected]), Jianwei Yang ([email protected])
9
+ # --------------------------------------------------------
10
+
11
+
12
+ import logging
13
+ from typing import Optional
14
+
15
+ import torch
16
+ from torch import nn, Tensor
17
+ from torch.nn import functional as F
18
+
19
+ from timm.models.layers import trunc_normal_
20
+ from detectron2.layers import Conv2d
21
+ import fvcore.nn.weight_init as weight_init
22
+
23
+ from .registry import register_decoder
24
+ from ...utils import configurable
25
+ from ...modules import PositionEmbeddingSine
26
+
27
+
28
+ class SelfAttentionLayer(nn.Module):
29
+
30
+ def __init__(self, d_model, nhead, dropout=0.0,
31
+ activation="relu", normalize_before=False):
32
+ super().__init__()
33
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
34
+
35
+ self.norm = nn.LayerNorm(d_model)
36
+ self.dropout = nn.Dropout(dropout)
37
+
38
+ self.activation = _get_activation_fn(activation)
39
+ self.normalize_before = normalize_before
40
+
41
+ self._reset_parameters()
42
+
43
+ def _reset_parameters(self):
44
+ for p in self.parameters():
45
+ if p.dim() > 1:
46
+ nn.init.xavier_uniform_(p)
47
+
48
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
49
+ return tensor if pos is None else tensor + pos
50
+
51
+ def forward_post(self, tgt,
52
+ tgt_mask: Optional[Tensor] = None,
53
+ tgt_key_padding_mask: Optional[Tensor] = None,
54
+ query_pos: Optional[Tensor] = None):
55
+ q = k = self.with_pos_embed(tgt, query_pos)
56
+ tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
57
+ key_padding_mask=tgt_key_padding_mask)[0]
58
+ tgt = tgt + self.dropout(tgt2)
59
+ tgt = self.norm(tgt)
60
+
61
+ return tgt
62
+
63
+ def forward_pre(self, tgt,
64
+ tgt_mask: Optional[Tensor] = None,
65
+ tgt_key_padding_mask: Optional[Tensor] = None,
66
+ query_pos: Optional[Tensor] = None):
67
+ tgt2 = self.norm(tgt)
68
+ q = k = self.with_pos_embed(tgt2, query_pos)
69
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
70
+ key_padding_mask=tgt_key_padding_mask)[0]
71
+ tgt = tgt + self.dropout(tgt2)
72
+
73
+ return tgt
74
+
75
+ def forward(self, tgt,
76
+ tgt_mask: Optional[Tensor] = None,
77
+ tgt_key_padding_mask: Optional[Tensor] = None,
78
+ query_pos: Optional[Tensor] = None):
79
+ if self.normalize_before:
80
+ return self.forward_pre(tgt, tgt_mask,
81
+ tgt_key_padding_mask, query_pos)
82
+ return self.forward_post(tgt, tgt_mask,
83
+ tgt_key_padding_mask, query_pos)
84
+
85
+
86
+ class CrossAttentionLayer(nn.Module):
87
+
88
+ def __init__(self, d_model, nhead, dropout=0.0,
89
+ activation="relu", normalize_before=False):
90
+ super().__init__()
91
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
92
+
93
+ self.norm = nn.LayerNorm(d_model)
94
+ self.dropout = nn.Dropout(dropout)
95
+
96
+ self.activation = _get_activation_fn(activation)
97
+ self.normalize_before = normalize_before
98
+
99
+ self._reset_parameters()
100
+
101
+ def _reset_parameters(self):
102
+ for p in self.parameters():
103
+ if p.dim() > 1:
104
+ nn.init.xavier_uniform_(p)
105
+
106
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
107
+ return tensor if pos is None else tensor + pos
108
+
109
+ def forward_post(self, tgt, memory,
110
+ memory_mask: Optional[Tensor] = None,
111
+ memory_key_padding_mask: Optional[Tensor] = None,
112
+ pos: Optional[Tensor] = None,
113
+ query_pos: Optional[Tensor] = None):
114
+ tgt2, avg_attn = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
115
+ key=self.with_pos_embed(memory, pos),
116
+ value=memory, attn_mask=memory_mask,
117
+ key_padding_mask=memory_key_padding_mask)
118
+ tgt = tgt + self.dropout(tgt2)
119
+ tgt = self.norm(tgt)
120
+ return tgt, avg_attn
121
+
122
+ def forward_pre(self, tgt, memory,
123
+ memory_mask: Optional[Tensor] = None,
124
+ memory_key_padding_mask: Optional[Tensor] = None,
125
+ pos: Optional[Tensor] = None,
126
+ query_pos: Optional[Tensor] = None):
127
+ tgt2 = self.norm(tgt)
128
+ tgt2, avg_attn = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
129
+ key=self.with_pos_embed(memory, pos),
130
+ value=memory, attn_mask=memory_mask,
131
+ key_padding_mask=memory_key_padding_mask)
132
+ tgt = tgt + self.dropout(tgt2)
133
+
134
+ return tgt, avg_attn
135
+
136
+ def forward(self, tgt, memory,
137
+ memory_mask: Optional[Tensor] = None,
138
+ memory_key_padding_mask: Optional[Tensor] = None,
139
+ pos: Optional[Tensor] = None,
140
+ query_pos: Optional[Tensor] = None):
141
+ if self.normalize_before:
142
+ return self.forward_pre(tgt, memory, memory_mask,
143
+ memory_key_padding_mask, pos, query_pos)
144
+ return self.forward_post(tgt, memory, memory_mask,
145
+ memory_key_padding_mask, pos, query_pos)
146
+
147
+
148
+ class FFNLayer(nn.Module):
149
+
150
+ def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
151
+ activation="relu", normalize_before=False):
152
+ super().__init__()
153
+ # Implementation of Feedforward model
154
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
155
+ self.dropout = nn.Dropout(dropout)
156
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
157
+
158
+ self.norm = nn.LayerNorm(d_model)
159
+
160
+ self.activation = _get_activation_fn(activation)
161
+ self.normalize_before = normalize_before
162
+
163
+ self._reset_parameters()
164
+
165
+ def _reset_parameters(self):
166
+ for p in self.parameters():
167
+ if p.dim() > 1:
168
+ nn.init.xavier_uniform_(p)
169
+
170
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
171
+ return tensor if pos is None else tensor + pos
172
+
173
+ def forward_post(self, tgt):
174
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
175
+ tgt = tgt + self.dropout(tgt2)
176
+ tgt = self.norm(tgt)
177
+ return tgt
178
+
179
+ def forward_pre(self, tgt):
180
+ tgt2 = self.norm(tgt)
181
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
182
+ tgt = tgt + self.dropout(tgt2)
183
+ return tgt
184
+
185
+ def forward(self, tgt):
186
+ if self.normalize_before:
187
+ return self.forward_pre(tgt)
188
+ return self.forward_post(tgt)
189
+
190
+
191
+ def _get_activation_fn(activation):
192
+ """Return an activation function given a string"""
193
+ if activation == "relu":
194
+ return F.relu
195
+ if activation == "gelu":
196
+ return F.gelu
197
+ if activation == "glu":
198
+ return F.glu
199
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
200
+
201
+
202
+ class MLP(nn.Module):
203
+ """ Very simple multi-layer perceptron (also called FFN)"""
204
+
205
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
206
+ super().__init__()
207
+ self.num_layers = num_layers
208
+ h = [hidden_dim] * (num_layers - 1)
209
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
210
+
211
+ def forward(self, x):
212
+ for i, layer in enumerate(self.layers):
213
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
214
+ return x
215
+
216
+
217
+ class MultiScaleMaskedTransformerDecoder(nn.Module):
218
+
219
+ _version = 2
220
+
221
+ @configurable
222
+ def __init__(
223
+ self,
224
+ lang_encoder: nn.Module,
225
+ in_channels,
226
+ mask_classification=True,
227
+ *,
228
+ hidden_dim: int,
229
+ dim_proj: int,
230
+ num_queries: int,
231
+ contxt_len: int,
232
+ nheads: int,
233
+ dim_feedforward: int,
234
+ dec_layers: int,
235
+ pre_norm: bool,
236
+ mask_dim: int,
237
+ task_switch: dict,
238
+ captioning_step: int,
239
+ enforce_input_project: bool,
240
+ ):
241
+ """
242
+ NOTE: this interface is experimental.
243
+ Args:
244
+ in_channels: channels of the input features
245
+ mask_classification: whether to add mask classifier or not
246
+ num_classes: number of classes
247
+ hidden_dim: Transformer feature dimension
248
+ num_queries: number of queries
249
+ nheads: number of heads
250
+ dim_feedforward: feature dimension in feedforward network
251
+ enc_layers: number of Transformer encoder layers
252
+ dec_layers: number of Transformer decoder layers
253
+ pre_norm: whether to use pre-LayerNorm or not
254
+ mask_dim: mask feature dimension
255
+ enforce_input_project: add input project 1x1 conv even if input
256
+ channels and hidden dim is identical
257
+ """
258
+ super().__init__()
259
+ assert mask_classification, "Only support mask classification model"
260
+ self.mask_classification = mask_classification
261
+
262
+ # positional encoding
263
+ N_steps = hidden_dim // 2
264
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
265
+
266
+ # define Transformer decoder here
267
+ self.num_heads = nheads
268
+ self.num_layers = dec_layers
269
+ self.contxt_len = contxt_len
270
+ self.transformer_self_attention_layers = nn.ModuleList()
271
+ self.transformer_cross_attention_layers = nn.ModuleList()
272
+ self.transformer_ffn_layers = nn.ModuleList()
273
+
274
+ for _ in range(self.num_layers):
275
+ self.transformer_self_attention_layers.append(
276
+ SelfAttentionLayer(
277
+ d_model=hidden_dim,
278
+ nhead=nheads,
279
+ dropout=0.0,
280
+ normalize_before=pre_norm,
281
+ )
282
+ )
283
+
284
+ self.transformer_cross_attention_layers.append(
285
+ CrossAttentionLayer(
286
+ d_model=hidden_dim,
287
+ nhead=nheads,
288
+ dropout=0.0,
289
+ normalize_before=pre_norm,
290
+ )
291
+ )
292
+
293
+ self.transformer_ffn_layers.append(
294
+ FFNLayer(
295
+ d_model=hidden_dim,
296
+ dim_feedforward=dim_feedforward,
297
+ dropout=0.0,
298
+ normalize_before=pre_norm,
299
+ )
300
+ )
301
+
302
+ self.decoder_norm = nn.LayerNorm(hidden_dim)
303
+
304
+ self.num_queries = num_queries
305
+ # learnable query features
306
+ self.query_feat = nn.Embedding(num_queries, hidden_dim)
307
+ # learnable query p.e.
308
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
309
+
310
+ # level embedding (we always use 3 scales)
311
+ self.num_feature_levels = 3
312
+ self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
313
+ self.input_proj = nn.ModuleList()
314
+
315
+ for _ in range(self.num_feature_levels):
316
+ if in_channels != hidden_dim or enforce_input_project:
317
+ self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
318
+ weight_init.c2_xavier_fill(self.input_proj[-1])
319
+ else:
320
+ self.input_proj.append(nn.Sequential())
321
+
322
+ self.task_switch = task_switch
323
+
324
+ # output FFNs
325
+ self.lang_encoder = lang_encoder
326
+ if self.task_switch['mask']:
327
+ self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
328
+
329
+ self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
330
+ trunc_normal_(self.class_embed, std=.02)
331
+
332
+ if task_switch['bbox']:
333
+ self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
334
+
335
+ # Caption Project and query
336
+ if task_switch['captioning']:
337
+ self.caping_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
338
+ trunc_normal_(self.caping_embed, std=.02)
339
+ # self.query_feat_caping = nn.Embedding(contxt_len, hidden_dim)
340
+ self.pos_embed_caping = nn.Embedding(contxt_len, hidden_dim)
341
+ self.captioning_step = captioning_step
342
+
343
+ # register self_attn_mask to avoid information leakage, it includes interaction between object query, class query and caping query
344
+ self_attn_mask = torch.zeros((1, num_queries + contxt_len, num_queries + contxt_len)).bool()
345
+ self_attn_mask[:, :num_queries, num_queries:] = True # object+class query does not attend with caption query.
346
+ self_attn_mask[:, num_queries:, num_queries:] = torch.triu(torch.ones((1, contxt_len, contxt_len)), diagonal=1).bool() # caption query only attend with previous token.
347
+ self_attn_mask[:, :num_queries-1, num_queries-1:num_queries] = True # object query does not attend with class query.
348
+ self_attn_mask[:, num_queries-1:num_queries, :num_queries-1] = True # class query does not attend with object query.
349
+ self.register_buffer("self_attn_mask", self_attn_mask)
350
+
351
+
352
+ @classmethod
353
+ def from_config(cls, cfg, in_channels, lang_encoder, mask_classification, extra):
354
+ ret = {}
355
+
356
+ ret["lang_encoder"] = lang_encoder
357
+ ret["in_channels"] = in_channels
358
+ ret["mask_classification"] = mask_classification
359
+
360
+ enc_cfg = cfg['MODEL']['ENCODER']
361
+ dec_cfg = cfg['MODEL']['DECODER']
362
+
363
+ ret["hidden_dim"] = dec_cfg['HIDDEN_DIM']
364
+ ret["dim_proj"] = cfg['MODEL']['DIM_PROJ']
365
+ ret["num_queries"] = dec_cfg['NUM_OBJECT_QUERIES']
366
+ ret["contxt_len"] = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']
367
+
368
+ # Transformer parameters:
369
+ ret["nheads"] = dec_cfg['NHEADS']
370
+ ret["dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD']
371
+
372
+ # NOTE: because we add learnable query features which requires supervision,
373
+ # we add minus 1 to decoder layers to be consistent with our loss
374
+ # implementation: that is, number of auxiliary losses is always
375
+ # equal to number of decoder layers. With learnable query features, the number of
376
+ # auxiliary losses equals number of decoders plus 1.
377
+ assert dec_cfg['DEC_LAYERS'] >= 1
378
+ ret["dec_layers"] = dec_cfg['DEC_LAYERS'] - 1
379
+ ret["pre_norm"] = dec_cfg['PRE_NORM']
380
+ ret["enforce_input_project"] = dec_cfg['ENFORCE_INPUT_PROJ']
381
+ ret["mask_dim"] = enc_cfg['MASK_DIM']
382
+
383
+ ret["task_switch"] = extra['task_switch']
384
+ ret["captioning_step"] = dec_cfg['CAPTIONING'].get('STEP', 50)
385
+
386
+ return ret
387
+
388
+ def forward(self, x, mask_features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
389
+ if task == 'captioning_infer':
390
+ return self.forward_captioning(x, mask_features, mask=mask, target_queries=target_queries, target_vlp=target_vlp, task=task, extra=extra)
391
+ # x is a list of multi-scale feature
392
+ assert len(x) == self.num_feature_levels
393
+ src = []
394
+ pos = []
395
+ size_list = []
396
+
397
+ # disable mask, it does not affect performance
398
+ del mask
399
+ for i in range(self.num_feature_levels):
400
+ size_list.append(x[i].shape[-2:])
401
+ pos.append(self.pe_layer(x[i], None).flatten(2))
402
+ src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])
403
+
404
+ # flatten NxCxHxW to HWxNxC
405
+ pos[-1] = pos[-1].permute(2, 0, 1)
406
+ src[-1] = src[-1].permute(2, 0, 1)
407
+
408
+ _, bs, _ = src[0].shape
409
+
410
+ # QxNxC
411
+ query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
412
+ output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
413
+
414
+ predictions_class = []
415
+ predictions_mask = []
416
+ predictions_bbox = []
417
+ predictions_caption = []
418
+ predictions_captioning = []
419
+
420
+ self_tgt_mask = None
421
+ if self.training and task == 'vlp' and self.task_switch['captioning']:
422
+ # output = torch.cat((output, self.query_feat_caping.weight.unsqueeze(1).repeat(1, bs, 1)), dim=0) # concat object query, class token and caption token.
423
+ caping_lang_embed = torch.cat([caption['caption_tokens'] for caption in target_vlp], dim=0).transpose(0, 1) # language output
424
+ _caping_lang_embed = caping_lang_embed.detach().clone()
425
+ output = torch.cat((output, _caping_lang_embed), dim=0) # concat object query, class token and caption token.
426
+ caping_lang_embed += self.pos_embed_caping.weight.unsqueeze(1).repeat(1, bs, 1)
427
+ query_embed = torch.cat((query_embed, caping_lang_embed), dim=0) # may not add at the beginning.
428
+ self_tgt_mask = self.self_attn_mask.repeat(output.shape[1]*self.num_heads, 1, 1)
429
+ elif (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']) \
430
+ or ((self.training and task == 'openimage') and self.task_switch['openimage']['grounding']):
431
+ self_tgt_mask = self.self_attn_mask[:,:self.num_queries,:self.num_queries].repeat(output.shape[1]*self.num_heads, 1, 1)
432
+ grounding_tokens = extra['grounding_tokens']
433
+ _grounding_tokens = grounding_tokens.detach().clone()
434
+ # initialize with negative attention at the beginning.
435
+ pad_tgt_mask = torch.ones((1, self.num_queries + (self.num_queries-1) + len(grounding_tokens), self.num_queries + (self.num_queries-1) + len(grounding_tokens)), device=self_tgt_mask.device).bool().repeat(output.shape[1]*self.num_heads, 1, 1)
436
+ pad_tgt_mask[:,:self.num_queries,:self.num_queries] = self_tgt_mask
437
+ pad_tgt_mask[:,self.num_queries:,self.num_queries:] = False # grounding tokens could attend with eatch other
438
+ self_tgt_mask = pad_tgt_mask
439
+ output = torch.cat((output, output[:-1]), dim=0)
440
+ query_embed = torch.cat((query_embed, query_embed[:-1]), dim=0) # also pad language embdding to fix embedding
441
+ else:
442
+ self_tgt_mask = self.self_attn_mask[:,:self.num_queries,:self.num_queries].repeat(output.shape[1]*self.num_heads, 1, 1)
443
+
444
+ # prediction heads on learnable query features
445
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0], task=task)
446
+ attn_mask = results["attn_mask"]
447
+ predictions_class.append(results["outputs_class"])
448
+ predictions_mask.append(results["outputs_mask"])
449
+ predictions_bbox.append(results["outputs_bbox"])
450
+ predictions_caption.append(results["outputs_caption"])
451
+ predictions_captioning.append(results["outputs_captionting"])
452
+
453
+ for i in range(self.num_layers):
454
+ level_index = i % self.num_feature_levels
455
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
456
+
457
+ if self.training and task == 'vlp' and self.task_switch['captioning']:
458
+ attn_mask = torch.cat((attn_mask, torch.zeros_like(attn_mask[:, :self.contxt_len, :])), dim=1)
459
+ # attention: cross-attention first
460
+ output, avg_attn = self.transformer_cross_attention_layers[i](
461
+ output, src[level_index],
462
+ memory_mask=attn_mask,
463
+ memory_key_padding_mask=None, # here we do not apply masking on padded region
464
+ pos=pos[level_index], query_pos=query_embed
465
+ )
466
+
467
+ if (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']) \
468
+ or ((self.training and task == 'openimage') and self.task_switch['openimage']['grounding']):
469
+ output = torch.cat((output, _grounding_tokens), dim=0)
470
+ query_embed = torch.cat((query_embed, grounding_tokens), dim=0)
471
+
472
+ output = self.transformer_self_attention_layers[i](
473
+ output, tgt_mask=self_tgt_mask,
474
+ tgt_key_padding_mask=None,
475
+ query_pos=query_embed
476
+ )
477
+
478
+ # FFN
479
+ output = self.transformer_ffn_layers[i](
480
+ output
481
+ )
482
+
483
+ if ((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding'] \
484
+ or ((self.training and task == 'openimage') and self.task_switch['openimage']['grounding']):
485
+ _grounding_tokens = output[-len(_grounding_tokens):]
486
+ output = output[:-len(_grounding_tokens)]
487
+ query_embed = query_embed[:-len(_grounding_tokens)]
488
+
489
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i, task=task)
490
+ attn_mask = results["attn_mask"]
491
+ predictions_class.append(results["outputs_class"])
492
+ predictions_mask.append(results["outputs_mask"])
493
+ predictions_bbox.append(results["outputs_bbox"])
494
+ predictions_caption.append(results["outputs_caption"])
495
+ predictions_captioning.append(results["outputs_captionting"])
496
+
497
+ assert len(predictions_class) == self.num_layers + 1
498
+ if task == 'vlp':
499
+ out = {'pred_captionings': predictions_captioning[-1],
500
+ 'pred_captions': predictions_caption[-1],
501
+ 'aux_outputs': [{'pred_captionings': x, 'pred_captions': y } for x, y in zip(predictions_captioning[:-1], predictions_caption[:-1])]}
502
+ return out
503
+ else:
504
+ out = {
505
+ 'pred_logits': predictions_class[-1],
506
+ 'pred_masks': predictions_mask[-1],
507
+ 'pred_boxes': predictions_bbox[-1],
508
+ 'pred_captions': predictions_caption[-1],
509
+ 'aux_outputs': self._set_aux_loss(
510
+ predictions_class if self.mask_classification else None, predictions_mask, predictions_bbox, predictions_caption
511
+ )
512
+ }
513
+ return out
514
+
515
+ def forward_captioning(self, x, mask_features, mask = None, target_queries = None, target_vlp = None, task='seg', extra={}):
516
+ # x is a list of multi-scale feature
517
+ assert len(x) == self.num_feature_levels
518
+ src = []
519
+ pos = []
520
+ size_list = []
521
+
522
+ # disable mask, it does not affect performance
523
+ del mask
524
+ for i in range(self.num_feature_levels):
525
+ size_list.append(x[i].shape[-2:])
526
+ pos.append(self.pe_layer(x[i], None).flatten(2))
527
+ src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])
528
+
529
+ # flatten NxCxHxW to HWxNxC
530
+ pos[-1] = pos[-1].permute(2, 0, 1)
531
+ src[-1] = src[-1].permute(2, 0, 1)
532
+
533
+ _, bs, _ = src[0].shape
534
+
535
+ # QxNxC
536
+ query_embed_ = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
537
+ query_feat = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
538
+ caping_lang_token = extra['start_token'].repeat(bs, 1)
539
+ start_id = 0
540
+ if 'token' in extra:
541
+ caping_lang_token[:,:len(extra['token'][0])] = extra['token']
542
+ start_id = len(extra['token'][0])-1
543
+ # query_feat_caping = self.query_feat_caping.weight.unsqueeze(1).repeat(1, bs, 1)
544
+ pos_embed_caping = self.pos_embed_caping.weight.unsqueeze(1).repeat(1, bs, 1)
545
+ # prepare token embedding for evaluation
546
+ token_embs = self.lang_encoder.lang_encoder.token_embedding.weight
547
+ # token_embs = (token_embs / token_embs.norm(dim=-1, keepdim=True) + 1e-7)
548
+
549
+ for cap_idx in range(start_id, self.captioning_step):
550
+ caping_lang_embed = self.lang_encoder.forward_language_token((caping_lang_token,))[0].transpose(0, 1)
551
+ output = torch.cat((query_feat, caping_lang_embed), dim=0) # concat object query, class token and caption token.
552
+ caping_lang_embed += pos_embed_caping
553
+ query_embed = torch.cat((query_embed_, caping_lang_embed), dim=0) # may not add at the beginning.
554
+ # output = torch.cat((query_feat, query_feat_caping), dim=0) # concat object query, class token and caption token.
555
+
556
+ # prediction heads on learnable query features
557
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0], task=task)
558
+ attn_mask = results["attn_mask"]
559
+
560
+ for i in range(self.num_layers):
561
+ level_index = i % self.num_feature_levels
562
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
563
+ attn_mask = torch.cat((attn_mask, torch.zeros_like(attn_mask[:, :self.contxt_len, :])), dim=1)
564
+ self_tgt_mask = self.self_attn_mask.repeat(output.shape[1]*self.num_heads, 1, 1)
565
+
566
+ if extra['captioning_mask'] is not None:
567
+ bs,nq,wh = attn_mask.shape
568
+ assert bs==self.num_heads, "Only support single image referring captioning."
569
+ cap_mask = extra['captioning_mask']
570
+ attn_mask = attn_mask.reshape(bs,nq,size_list[i%3][0],size_list[i%3][1])
571
+ cap_mask = F.interpolate(cap_mask[None,].float(), size_list[i%3], mode='nearest').bool()[0,0]
572
+ attn_mask[:,self.num_queries:, cap_mask] = True
573
+ attn_mask = attn_mask.reshape(bs,nq,wh)
574
+
575
+ # attention: cross-attention first
576
+ output, avg_attn = self.transformer_cross_attention_layers[i](
577
+ output, src[level_index],
578
+ memory_mask=attn_mask,
579
+ memory_key_padding_mask=None, # here we do not apply masking on padded region
580
+ pos=pos[level_index], query_pos=query_embed
581
+ )
582
+
583
+ output = self.transformer_self_attention_layers[i](
584
+ output, tgt_mask=self_tgt_mask,
585
+ tgt_key_padding_mask=None,
586
+ query_pos=query_embed
587
+ )
588
+
589
+ # FFN
590
+ output = self.transformer_ffn_layers[i](
591
+ output
592
+ )
593
+
594
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i, task=task)
595
+ attn_mask = results["attn_mask"]
596
+
597
+ pred_captions_gen = results['outputs_captionting']
598
+ # pred_captions_gen = (pred_captions_gen / pred_captions_gen.norm(dim=-1, keepdim=True) + 1e-7)
599
+ pred_captions_gen = pred_captions_gen @ token_embs.t()
600
+ caping_lang_token[:,cap_idx+1] = pred_captions_gen[:,cap_idx].max(-1)[1]
601
+
602
+ texts = self.lang_encoder.tokenizer.batch_decode(caping_lang_token, skip_special_tokens=False)
603
+ texts_new = []
604
+
605
+ for x in texts:
606
+ x = x.split('<|endoftext|>')[0]
607
+ x = x.replace('<|endoftext|>','')
608
+ x = x.replace('<|startoftext|>','')
609
+ x = x.strip()
610
+ texts_new.append(x)
611
+
612
+ out = {'pred_captionings': caping_lang_token,
613
+ 'pred_texts': texts_new}
614
+ return out
615
+
616
+
617
+ def forward_prediction_heads(self, output, mask_features, attn_mask_target_size, layer_id=-1, task='seg'):
618
+ decoder_output = self.decoder_norm(output)
619
+ decoder_output = decoder_output.transpose(0, 1)
620
+
621
+ # extract image captioning token from decoder output.
622
+ if self.task_switch['captioning'] and (task == 'vlp' or task == 'captioning_infer'):
623
+ outputs_captionting = decoder_output[:,self.num_queries:] @ self.caping_embed
624
+ else:
625
+ outputs_captionting = None
626
+
627
+ # recompute class token output.
628
+ norm_decoder_output = decoder_output / (decoder_output.norm(dim=-1, keepdim=True) + 1e-7)
629
+ obj_token = norm_decoder_output[:,:self.num_queries-1]
630
+ cls_token = norm_decoder_output[:,self.num_queries-1:self.num_queries]
631
+
632
+ sim = (cls_token @ obj_token.transpose(1,2)).softmax(-1)[:,0,:,None] # TODO include class token.
633
+ cls_token = (sim * decoder_output[:,:self.num_queries-1]).sum(dim=1, keepdim=True)
634
+
635
+ if (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']) \
636
+ or ((self.training and task == 'openimage') and self.task_switch['openimage']['grounding']):
637
+ decoder_output = torch.cat((decoder_output[:,:self.num_queries-1], cls_token, decoder_output[:,self.num_queries:2*self.num_queries-1]), dim=1)
638
+ else:
639
+ decoder_output = torch.cat((decoder_output[:,:self.num_queries-1], cls_token), dim=1)
640
+
641
+ # compute class, mask and bbox.
642
+ class_embed = decoder_output @ self.class_embed
643
+ # HACK do not compute similarity if mask is not on
644
+ outputs_class = self.lang_encoder.compute_similarity(class_embed, fake=(((not self.task_switch['mask']) and self.training) or (task == 'openimage')))
645
+
646
+ if self.task_switch['mask'] or self.task_switch['openimage']['mask']:
647
+ mask_embed = self.mask_embed(decoder_output)
648
+ outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
649
+
650
+ # NOTE: prediction is of higher-resolution
651
+ # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
652
+ attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)
653
+
654
+ # must use bool type
655
+ # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
656
+ attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
657
+ attn_mask = attn_mask.detach()
658
+
659
+ # NOTE: fill False for cls token (JY)
660
+ attn_mask[:, self.num_queries:self.num_queries+1].fill_(False)
661
+ else:
662
+ outputs_mask = None
663
+ attn_mask = torch.zeros((list(decoder_output.shape[:2]) + [attn_mask_target_size[0]*attn_mask_target_size[1]]), device=decoder_output.device).repeat(self.num_heads, 1, 1).bool()
664
+
665
+ outputs_bbox = [None for i in range(len(decoder_output))]
666
+ if self.task_switch['bbox']:
667
+ outputs_bbox = self.bbox_embed(decoder_output)
668
+
669
+ outputs_caption = None
670
+ if self.task_switch['caption']:
671
+ outputs_caption = class_embed
672
+
673
+
674
+ results = {
675
+ "outputs_class": outputs_class,
676
+ "outputs_mask": outputs_mask,
677
+ "outputs_bbox": outputs_bbox,
678
+ "attn_mask": attn_mask,
679
+ "outputs_caption": outputs_caption,
680
+ "outputs_captionting": outputs_captionting,
681
+ }
682
+ return results
683
+
684
+ @torch.jit.unused
685
+ def _set_aux_loss(self, outputs_class, outputs_seg_masks, outputs_boxes, outputs_captions):
686
+ # this is a workaround to make torchscript happy, as torchscript
687
+ # doesn't support dictionary with non-homogeneous values, such
688
+ # as a dict having both a Tensor and a list.
689
+ if self.mask_classification:
690
+ return [
691
+ {"pred_logits": a, "pred_masks": b, "pred_boxes": c, "pred_captions": d}
692
+ for a, b, c, d in zip(outputs_class[:-1], outputs_seg_masks[:-1], outputs_boxes[:-1], outputs_captions[:-1])
693
+ ]
694
+ else:
695
+ return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]
696
+
697
+
698
+ @register_decoder
699
+ def get_masked_transformer_decoder(cfg, in_channels, lang_encoder, mask_classification, extra):
700
+ return MultiScaleMaskedTransformerDecoder(cfg, in_channels, lang_encoder, mask_classification, extra)
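
Note on forward_prediction_heads above: the predicted mask logits are resized to the current feature-level resolution, passed through a sigmoid, and every location scoring below 0.5 is marked True (blocked) for each attention head before the mask is detached; a single query row is then forced to remain attendable, mirroring the fill_(False) call. A minimal sketch of that thresholding step, assuming illustrative tensor shapes rather than the module's configured sizes:

import torch
import torch.nn.functional as F

def mask_logits_to_attn_mask(outputs_mask, target_size, num_heads, num_queries):
    # outputs_mask: [B, Q, H, W] mask logits; target_size: (h, w) of the current feature level
    m = F.interpolate(outputs_mask, size=target_size, mode="bilinear", align_corners=False)
    # [B, Q, h*w] -> [B, heads, Q, h*w] -> [B*heads, Q, h*w]; True means "do not attend"
    attn_mask = (m.sigmoid().flatten(2).unsqueeze(1)
                 .repeat(1, num_heads, 1, 1).flatten(0, 1) < 0.5).detach()
    # keep one query row fully attendable, mirroring the fill_(False) call above
    attn_mask[:, num_queries:num_queries + 1] = False
    return attn_mask

# illustrative call (all shapes are placeholders):
# mask_logits_to_attn_mask(torch.randn(2, 101, 64, 64), (32, 32), 8, 101)
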
xdecoder/body/decoder/xdecoder2.py ADDED
@@ -0,0 +1,700 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
3
+
4
+ # --------------------------------------------------------
5
+ # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
6
+ # Copyright (c) 2022 Microsoft
7
+ # Licensed under The MIT License [see LICENSE for details]
8
+ # Written by Xueyan Zou ([email protected]), Jianwei Yang ([email protected])
9
+ # --------------------------------------------------------
10
+
11
+
12
+ import logging
13
+ from typing import Optional
14
+
15
+ import torch
16
+ from torch import nn, Tensor
17
+ from torch.nn import functional as F
18
+
19
+ from timm.models.layers import trunc_normal_
20
+ from detectron2.layers import Conv2d
21
+ import fvcore.nn.weight_init as weight_init
22
+
23
+ from .registry import register_decoder
24
+ from ...utils import configurable
25
+ from ...modules import PositionEmbeddingSine
26
+
27
+
28
+ class SelfAttentionLayer(nn.Module):
29
+
30
+ def __init__(self, d_model, nhead, dropout=0.0,
31
+ activation="relu", normalize_before=False):
32
+ super().__init__()
33
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
34
+
35
+ self.norm = nn.LayerNorm(d_model)
36
+ self.dropout = nn.Dropout(dropout)
37
+
38
+ self.activation = _get_activation_fn(activation)
39
+ self.normalize_before = normalize_before
40
+
41
+ self._reset_parameters()
42
+
43
+ def _reset_parameters(self):
44
+ for p in self.parameters():
45
+ if p.dim() > 1:
46
+ nn.init.xavier_uniform_(p)
47
+
48
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
49
+ return tensor if pos is None else tensor + pos
50
+
51
+ def forward_post(self, tgt,
52
+ tgt_mask: Optional[Tensor] = None,
53
+ tgt_key_padding_mask: Optional[Tensor] = None,
54
+ query_pos: Optional[Tensor] = None):
55
+ q = k = self.with_pos_embed(tgt, query_pos)
56
+ tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
57
+ key_padding_mask=tgt_key_padding_mask)[0]
58
+ tgt = tgt + self.dropout(tgt2)
59
+ tgt = self.norm(tgt)
60
+
61
+ return tgt
62
+
63
+ def forward_pre(self, tgt,
64
+ tgt_mask: Optional[Tensor] = None,
65
+ tgt_key_padding_mask: Optional[Tensor] = None,
66
+ query_pos: Optional[Tensor] = None):
67
+ tgt2 = self.norm(tgt)
68
+ q = k = self.with_pos_embed(tgt2, query_pos)
69
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
70
+ key_padding_mask=tgt_key_padding_mask)[0]
71
+ tgt = tgt + self.dropout(tgt2)
72
+
73
+ return tgt
74
+
75
+ def forward(self, tgt,
76
+ tgt_mask: Optional[Tensor] = None,
77
+ tgt_key_padding_mask: Optional[Tensor] = None,
78
+ query_pos: Optional[Tensor] = None):
79
+ if self.normalize_before:
80
+ return self.forward_pre(tgt, tgt_mask,
81
+ tgt_key_padding_mask, query_pos)
82
+ return self.forward_post(tgt, tgt_mask,
83
+ tgt_key_padding_mask, query_pos)
84
+
85
+
86
+ class CrossAttentionLayer(nn.Module):
87
+
88
+ def __init__(self, d_model, nhead, dropout=0.0,
89
+ activation="relu", normalize_before=False):
90
+ super().__init__()
91
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
92
+
93
+ self.norm = nn.LayerNorm(d_model)
94
+ self.dropout = nn.Dropout(dropout)
95
+
96
+ self.activation = _get_activation_fn(activation)
97
+ self.normalize_before = normalize_before
98
+
99
+ self._reset_parameters()
100
+
101
+ def _reset_parameters(self):
102
+ for p in self.parameters():
103
+ if p.dim() > 1:
104
+ nn.init.xavier_uniform_(p)
105
+
106
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
107
+ return tensor if pos is None else tensor + pos
108
+
109
+ def forward_post(self, tgt, memory,
110
+ memory_mask: Optional[Tensor] = None,
111
+ memory_key_padding_mask: Optional[Tensor] = None,
112
+ pos: Optional[Tensor] = None,
113
+ query_pos: Optional[Tensor] = None):
114
+ tgt2, avg_attn = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
115
+ key=self.with_pos_embed(memory, pos),
116
+ value=memory, attn_mask=memory_mask,
117
+ key_padding_mask=memory_key_padding_mask)
118
+ tgt = tgt + self.dropout(tgt2)
119
+ tgt = self.norm(tgt)
120
+ return tgt, avg_attn
121
+
122
+ def forward_pre(self, tgt, memory,
123
+ memory_mask: Optional[Tensor] = None,
124
+ memory_key_padding_mask: Optional[Tensor] = None,
125
+ pos: Optional[Tensor] = None,
126
+ query_pos: Optional[Tensor] = None):
127
+ tgt2 = self.norm(tgt)
128
+ tgt2, avg_attn = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
129
+ key=self.with_pos_embed(memory, pos),
130
+ value=memory, attn_mask=memory_mask,
131
+ key_padding_mask=memory_key_padding_mask)
132
+ tgt = tgt + self.dropout(tgt2)
133
+
134
+ return tgt, avg_attn
135
+
136
+ def forward(self, tgt, memory,
137
+ memory_mask: Optional[Tensor] = None,
138
+ memory_key_padding_mask: Optional[Tensor] = None,
139
+ pos: Optional[Tensor] = None,
140
+ query_pos: Optional[Tensor] = None):
141
+ if self.normalize_before:
142
+ return self.forward_pre(tgt, memory, memory_mask,
143
+ memory_key_padding_mask, pos, query_pos)
144
+ return self.forward_post(tgt, memory, memory_mask,
145
+ memory_key_padding_mask, pos, query_pos)
146
+
147
+
148
+ class FFNLayer(nn.Module):
149
+
150
+ def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
151
+ activation="relu", normalize_before=False):
152
+ super().__init__()
153
+ # Implementation of Feedforward model
154
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
155
+ self.dropout = nn.Dropout(dropout)
156
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
157
+
158
+ self.norm = nn.LayerNorm(d_model)
159
+
160
+ self.activation = _get_activation_fn(activation)
161
+ self.normalize_before = normalize_before
162
+
163
+ self._reset_parameters()
164
+
165
+ def _reset_parameters(self):
166
+ for p in self.parameters():
167
+ if p.dim() > 1:
168
+ nn.init.xavier_uniform_(p)
169
+
170
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
171
+ return tensor if pos is None else tensor + pos
172
+
173
+ def forward_post(self, tgt):
174
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
175
+ tgt = tgt + self.dropout(tgt2)
176
+ tgt = self.norm(tgt)
177
+ return tgt
178
+
179
+ def forward_pre(self, tgt):
180
+ tgt2 = self.norm(tgt)
181
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
182
+ tgt = tgt + self.dropout(tgt2)
183
+ return tgt
184
+
185
+ def forward(self, tgt):
186
+ if self.normalize_before:
187
+ return self.forward_pre(tgt)
188
+ return self.forward_post(tgt)
189
+
190
+
191
+ def _get_activation_fn(activation):
192
+ """Return an activation function given a string"""
193
+ if activation == "relu":
194
+ return F.relu
195
+ if activation == "gelu":
196
+ return F.gelu
197
+ if activation == "glu":
198
+ return F.glu
199
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
200
+
201
+
202
+ class MLP(nn.Module):
203
+ """ Very simple multi-layer perceptron (also called FFN)"""
204
+
205
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
206
+ super().__init__()
207
+ self.num_layers = num_layers
208
+ h = [hidden_dim] * (num_layers - 1)
209
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
210
+
211
+ def forward(self, x):
212
+ for i, layer in enumerate(self.layers):
213
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
214
+ return x
215
+
216
+
217
+ class MultiScaleMaskedTransformerDecoder(nn.Module):
218
+
219
+ _version = 2
220
+
221
+ @configurable
222
+ def __init__(
223
+ self,
224
+ lang_encoder: nn.Module,
225
+ in_channels,
226
+ mask_classification=True,
227
+ *,
228
+ hidden_dim: int,
229
+ dim_proj: int,
230
+ num_queries: int,
231
+ contxt_len: int,
232
+ nheads: int,
233
+ dim_feedforward: int,
234
+ dec_layers: int,
235
+ pre_norm: bool,
236
+ mask_dim: int,
237
+ task_switch: dict,
238
+ captioning_step: int,
239
+ enforce_input_project: bool,
240
+ ):
241
+ """
242
+ NOTE: this interface is experimental.
243
+ Args:
244
+ in_channels: channels of the input features
245
+ mask_classification: whether to add mask classifier or not
246
+ num_classes: number of classes
247
+ hidden_dim: Transformer feature dimension
248
+ num_queries: number of queries
249
+ nheads: number of heads
250
+ dim_feedforward: feature dimension in feedforward network
251
+ enc_layers: number of Transformer encoder layers
252
+ dec_layers: number of Transformer decoder layers
253
+ pre_norm: whether to use pre-LayerNorm or not
254
+ mask_dim: mask feature dimension
255
+ enforce_input_project: add input project 1x1 conv even if input
256
+ channels and hidden dim are identical
257
+ """
258
+ super().__init__()
259
+ assert mask_classification, "Only support mask classification model"
260
+ self.mask_classification = mask_classification
261
+
262
+ # positional encoding
263
+ N_steps = hidden_dim // 2
264
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
265
+
266
+ # define Transformer decoder here
267
+ self.num_heads = nheads
268
+ self.num_layers = dec_layers
269
+ self.contxt_len = contxt_len
270
+ self.transformer_self_attention_layers = nn.ModuleList()
271
+ self.transformer_cross_attention_layers = nn.ModuleList()
272
+ self.transformer_ffn_layers = nn.ModuleList()
273
+
274
+ for _ in range(self.num_layers):
275
+ self.transformer_self_attention_layers.append(
276
+ SelfAttentionLayer(
277
+ d_model=hidden_dim,
278
+ nhead=nheads,
279
+ dropout=0.0,
280
+ normalize_before=pre_norm,
281
+ )
282
+ )
283
+
284
+ self.transformer_cross_attention_layers.append(
285
+ CrossAttentionLayer(
286
+ d_model=hidden_dim,
287
+ nhead=nheads,
288
+ dropout=0.0,
289
+ normalize_before=pre_norm,
290
+ )
291
+ )
292
+
293
+ self.transformer_ffn_layers.append(
294
+ FFNLayer(
295
+ d_model=hidden_dim,
296
+ dim_feedforward=dim_feedforward,
297
+ dropout=0.0,
298
+ normalize_before=pre_norm,
299
+ )
300
+ )
301
+
302
+ self.decoder_norm = nn.LayerNorm(hidden_dim)
303
+
304
+ self.num_queries = num_queries
305
+ # learnable query features
306
+ self.query_feat = nn.Embedding(num_queries, hidden_dim)
307
+ # learnable query p.e.
308
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
309
+
310
+ # level embedding (we always use 3 scales)
311
+ self.num_feature_levels = 3
312
+ self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
313
+ self.input_proj = nn.ModuleList()
314
+
315
+ for _ in range(self.num_feature_levels):
316
+ if in_channels != hidden_dim or enforce_input_project:
317
+ self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
318
+ weight_init.c2_xavier_fill(self.input_proj[-1])
319
+ else:
320
+ self.input_proj.append(nn.Sequential())
321
+
322
+ self.task_switch = task_switch
323
+
324
+ # output FFNs
325
+ self.lang_encoder = lang_encoder
326
+ if self.task_switch['mask']:
327
+ self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
328
+
329
+ self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
330
+ trunc_normal_(self.class_embed, std=.02)
331
+
332
+ if task_switch['bbox']:
333
+ self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
334
+
335
+ # Caption Project and query
336
+ if task_switch['captioning']:
337
+ self.caping_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
338
+ trunc_normal_(self.caping_embed, std=.02)
339
+ self.query_feat_caping = nn.Embedding(contxt_len, hidden_dim)
340
+ # self.pos_embed_caping = nn.Embedding(contxt_len, hidden_dim)
341
+ self.captioning_step = captioning_step
342
+
343
+ # register self_attn_mask to avoid information leakage; it covers the interactions between object queries, the class query and the captioning queries
344
+ self_attn_mask = torch.zeros((1, num_queries + contxt_len, num_queries + contxt_len)).bool()
345
+ self_attn_mask[:, :num_queries, num_queries:] = True # object and class queries do not attend to caption tokens.
346
+ self_attn_mask[:, num_queries:, num_queries:] = torch.triu(torch.ones((1, contxt_len, contxt_len)), diagonal=1).bool() # caption tokens only attend to previous caption tokens (causal).
347
+ self_attn_mask[:, :num_queries-1, num_queries-1:num_queries] = True # object queries do not attend to the class query.
348
+ self_attn_mask[:, num_queries-1:num_queries, :num_queries-1] = True # the class query does not attend to object queries.
349
+ self.register_buffer("self_attn_mask", self_attn_mask)
350
+
351
+
352
+ @classmethod
353
+ def from_config(cls, cfg, in_channels, lang_encoder, mask_classification, extra):
354
+ ret = {}
355
+
356
+ ret["lang_encoder"] = lang_encoder
357
+ ret["in_channels"] = in_channels
358
+ ret["mask_classification"] = mask_classification
359
+
360
+ enc_cfg = cfg['MODEL']['ENCODER']
361
+ dec_cfg = cfg['MODEL']['DECODER']
362
+
363
+ ret["hidden_dim"] = dec_cfg['HIDDEN_DIM']
364
+ ret["dim_proj"] = cfg['MODEL']['DIM_PROJ']
365
+ ret["num_queries"] = dec_cfg['NUM_OBJECT_QUERIES']
366
+ ret["contxt_len"] = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']
367
+
368
+ # Transformer parameters:
369
+ ret["nheads"] = dec_cfg['NHEADS']
370
+ ret["dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD']
371
+
372
+ # NOTE: because we add learnable query features which require supervision,
373
+ # we subtract 1 from the decoder layers to be consistent with our loss
374
+ # implementation: that is, number of auxiliary losses is always
375
+ # equal to number of decoder layers. With learnable query features, the number of
376
+ # auxiliary losses equals number of decoders plus 1.
377
+ assert dec_cfg['DEC_LAYERS'] >= 1
378
+ ret["dec_layers"] = dec_cfg['DEC_LAYERS'] - 1
379
+ ret["pre_norm"] = dec_cfg['PRE_NORM']
380
+ ret["enforce_input_project"] = dec_cfg['ENFORCE_INPUT_PROJ']
381
+ ret["mask_dim"] = enc_cfg['MASK_DIM']
382
+
383
+ ret["task_switch"] = extra['task_switch']
384
+ ret["captioning_step"] = dec_cfg['CAPTIONING'].get('STEP', 50)
385
+
386
+ return ret
387
+
388
+ def forward(self, x, mask_features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
389
+ if task == 'captioning_infer':
390
+ return self.forward_captioning(x, mask_features, mask=mask, target_queries=target_queries, target_vlp=target_vlp, task=task, extra=extra)
391
+ # x is a list of multi-scale feature
392
+ assert len(x) == self.num_feature_levels
393
+ src = []
394
+ pos = []
395
+ size_list = []
396
+
397
+ # disable mask, it does not affect performance
398
+ del mask
399
+ for i in range(self.num_feature_levels):
400
+ size_list.append(x[i].shape[-2:])
401
+ pos.append(self.pe_layer(x[i], None).flatten(2))
402
+ src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])
403
+
404
+ # flatten NxCxHxW to HWxNxC
405
+ pos[-1] = pos[-1].permute(2, 0, 1)
406
+ src[-1] = src[-1].permute(2, 0, 1)
407
+
408
+ _, bs, _ = src[0].shape
409
+
410
+ # QxNxC
411
+ query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
412
+ output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
413
+
414
+ predictions_class = []
415
+ predictions_mask = []
416
+ predictions_bbox = []
417
+ predictions_caption = []
418
+ predictions_captioning = []
419
+
420
+ self_tgt_mask = None
421
+ if self.training and task == 'vlp' and self.task_switch['captioning']:
422
+ output = torch.cat((output, self.query_feat_caping.weight.unsqueeze(1).repeat(1, bs, 1)), dim=0) # concat object query, class token and caption token.
423
+ caping_lang_embed = torch.cat([caption['caption_tokens'] for caption in target_vlp], dim=0).transpose(0, 1) # language output
424
+ # _caping_lang_embed = caping_lang_embed.detach().clone()
425
+ # output = torch.cat((output, _caping_lang_embed), dim=0) # concat object query, class token and caption token.
426
+ # caping_lang_embed += self.pos_embed_caping.weight.unsqueeze(1).repeat(1, bs, 1)
427
+ query_embed = torch.cat((query_embed, caping_lang_embed), dim=0) # may not add at the beginning.
428
+ self_tgt_mask = self.self_attn_mask.repeat(output.shape[1]*self.num_heads, 1, 1)
429
+ elif (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']) \
430
+ or ((self.training and task == 'openimage') and self.task_switch['openimage']['grounding']):
431
+ self_tgt_mask = self.self_attn_mask[:,:self.num_queries,:self.num_queries].repeat(output.shape[1]*self.num_heads, 1, 1)
432
+ grounding_tokens = extra['grounding_tokens']
433
+ _grounding_tokens = grounding_tokens.detach().clone()
434
+ # initialize with negative attention at the beginning.
435
+ pad_tgt_mask = torch.ones((1, self.num_queries + (self.num_queries-1) + len(grounding_tokens), self.num_queries + (self.num_queries-1) + len(grounding_tokens)), device=self_tgt_mask.device).bool().repeat(output.shape[1]*self.num_heads, 1, 1)
436
+ pad_tgt_mask[:,:self.num_queries,:self.num_queries] = self_tgt_mask
437
+ pad_tgt_mask[:,self.num_queries:,self.num_queries:] = False # grounding tokens can attend to each other
438
+ self_tgt_mask = pad_tgt_mask
439
+ output = torch.cat((output, output[:-1]), dim=0)
440
+ query_embed = torch.cat((query_embed, query_embed[:-1]), dim=0) # also pad language embedding to fix embedding
441
+ else:
442
+ self_tgt_mask = self.self_attn_mask[:,:self.num_queries,:self.num_queries].repeat(output.shape[1]*self.num_heads, 1, 1)
443
+
444
+ # prediction heads on learnable query features
445
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0], task=task)
446
+ attn_mask = results["attn_mask"]
447
+ predictions_class.append(results["outputs_class"])
448
+ predictions_mask.append(results["outputs_mask"])
449
+ predictions_bbox.append(results["outputs_bbox"])
450
+ predictions_caption.append(results["outputs_caption"])
451
+ predictions_captioning.append(results["outputs_captionting"])
452
+
453
+ for i in range(self.num_layers):
454
+ level_index = i % self.num_feature_levels
455
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
456
+
457
+ if self.training and task == 'vlp' and self.task_switch['captioning']:
458
+ attn_mask = torch.cat((attn_mask, torch.zeros_like(attn_mask[:, :self.contxt_len, :])), dim=1)
459
+ # attention: cross-attention first
460
+ output, avg_attn = self.transformer_cross_attention_layers[i](
461
+ output, src[level_index],
462
+ memory_mask=attn_mask,
463
+ memory_key_padding_mask=None, # here we do not apply masking on padded region
464
+ pos=pos[level_index], query_pos=query_embed
465
+ )
466
+
467
+ if (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']) \
468
+ or ((self.training and task == 'openimage') and self.task_switch['openimage']['grounding']):
469
+ output = torch.cat((output, _grounding_tokens), dim=0)
470
+ query_embed = torch.cat((query_embed, grounding_tokens), dim=0)
471
+
472
+ output = self.transformer_self_attention_layers[i](
473
+ output, tgt_mask=self_tgt_mask,
474
+ tgt_key_padding_mask=None,
475
+ query_pos=query_embed
476
+ )
477
+
478
+ # FFN
479
+ output = self.transformer_ffn_layers[i](
480
+ output
481
+ )
482
+
483
+ if ((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding'] \
484
+ or ((self.training and task == 'openimage') and self.task_switch['openimage']['grounding']):
485
+ _grounding_tokens = output[-len(_grounding_tokens):]
486
+ output = output[:-len(_grounding_tokens)]
487
+ query_embed = query_embed[:-len(_grounding_tokens)]
488
+
489
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i, task=task)
490
+ attn_mask = results["attn_mask"]
491
+ predictions_class.append(results["outputs_class"])
492
+ predictions_mask.append(results["outputs_mask"])
493
+ predictions_bbox.append(results["outputs_bbox"])
494
+ predictions_caption.append(results["outputs_caption"])
495
+ predictions_captioning.append(results["outputs_captionting"])
496
+
497
+ assert len(predictions_class) == self.num_layers + 1
498
+ if task == 'vlp':
499
+ out = {'pred_captionings': predictions_captioning[-1],
500
+ 'pred_captions': predictions_caption[-1],
501
+ 'aux_outputs': [{'pred_captionings': x, 'pred_captions': y } for x, y in zip(predictions_captioning[:-1], predictions_caption[:-1])]}
502
+ return out
503
+ else:
504
+ out = {
505
+ 'pred_logits': predictions_class[-1],
506
+ 'pred_masks': predictions_mask[-1],
507
+ 'pred_boxes': predictions_bbox[-1],
508
+ 'pred_captions': predictions_caption[-1],
509
+ 'aux_outputs': self._set_aux_loss(
510
+ predictions_class if self.mask_classification else None, predictions_mask, predictions_bbox, predictions_caption
511
+ )
512
+ }
513
+ return out
514
+
515
+ def forward_captioning(self, x, mask_features, mask = None, target_queries = None, target_vlp = None, task='seg', extra={}):
516
+ # x is a list of multi-scale feature
517
+ assert len(x) == self.num_feature_levels
518
+ src = []
519
+ pos = []
520
+ size_list = []
521
+
522
+ # disable mask, it does not affect performance
523
+ del mask
524
+ for i in range(self.num_feature_levels):
525
+ size_list.append(x[i].shape[-2:])
526
+ pos.append(self.pe_layer(x[i], None).flatten(2))
527
+ src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])
528
+
529
+ # flatten NxCxHxW to HWxNxC
530
+ pos[-1] = pos[-1].permute(2, 0, 1)
531
+ src[-1] = src[-1].permute(2, 0, 1)
532
+
533
+ _, bs, _ = src[0].shape
534
+
535
+ # QxNxC
536
+ query_embed_ = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
537
+ query_feat = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
538
+ caping_lang_token = extra['start_token'].repeat(bs, 1)
539
+ start_id = 0
540
+ if 'token' in extra:
541
+ caping_lang_token[:,:len(extra['token'][0])] = extra['token']
542
+ start_id = len(extra['token'][0])-1
543
+ query_feat_caping = self.query_feat_caping.weight.unsqueeze(1).repeat(1, bs, 1)
544
+ # pos_embed_caping = self.pos_embed_caping.weight.unsqueeze(1).repeat(1, bs, 1)
545
+ # prepare token embedding for evaluation
546
+ token_embs = self.lang_encoder.lang_encoder.token_embedding.weight
547
+ # token_embs = (token_embs / token_embs.norm(dim=-1, keepdim=True) + 1e-7)
548
+
549
+ for cap_idx in range(start_id, self.captioning_step):
550
+ caping_lang_embed = self.lang_encoder.forward_language_token((caping_lang_token,))[0].transpose(0, 1)
551
+ # output = torch.cat((query_feat, caping_lang_embed), dim=0) # concat object query, class token and caption token.
552
+ # caping_lang_embed += pos_embed_caping
553
+ query_embed = torch.cat((query_embed_, caping_lang_embed), dim=0) # may not add at the beginning.
554
+ output = torch.cat((query_feat, query_feat_caping), dim=0) # concat object query, class token and caption token.
555
+
556
+ # prediction heads on learnable query features
557
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0], task=task)
558
+ attn_mask = results["attn_mask"]
559
+
560
+ for i in range(self.num_layers):
561
+ level_index = i % self.num_feature_levels
562
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
563
+ attn_mask = torch.cat((attn_mask, torch.zeros_like(attn_mask[:, :self.contxt_len, :])), dim=1)
564
+ self_tgt_mask = self.self_attn_mask.repeat(output.shape[1]*self.num_heads, 1, 1)
565
+
566
+ if extra['captioning_mask'] is not None:
567
+ bs,nq,wh = attn_mask.shape
568
+ assert bs==self.num_heads, "Only support single image referring captioning."
569
+ cap_mask = extra['captioning_mask']
570
+ attn_mask = attn_mask.reshape(bs,nq,size_list[i%3][0],size_list[i%3][1])
571
+ cap_mask = F.interpolate(cap_mask[None,].float(), size_list[i%3], mode='nearest').bool()[0,0]
572
+ attn_mask[:,self.num_queries:, cap_mask] = True
573
+ attn_mask = attn_mask.reshape(bs,nq,wh)
574
+
575
+ # attention: cross-attention first
576
+ output, avg_attn = self.transformer_cross_attention_layers[i](
577
+ output, src[level_index],
578
+ memory_mask=attn_mask,
579
+ memory_key_padding_mask=None, # here we do not apply masking on padded region
580
+ pos=pos[level_index], query_pos=query_embed
581
+ )
582
+
583
+ output = self.transformer_self_attention_layers[i](
584
+ output, tgt_mask=self_tgt_mask,
585
+ tgt_key_padding_mask=None,
586
+ query_pos=query_embed
587
+ )
588
+
589
+ # FFN
590
+ output = self.transformer_ffn_layers[i](
591
+ output
592
+ )
593
+
594
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i, task=task)
595
+ attn_mask = results["attn_mask"]
596
+
597
+ pred_captions_gen = results['outputs_captionting']
598
+ # pred_captions_gen = (pred_captions_gen / pred_captions_gen.norm(dim=-1, keepdim=True) + 1e-7)
599
+ pred_captions_gen = pred_captions_gen @ token_embs.t()
600
+ caping_lang_token[:,cap_idx+1] = pred_captions_gen[:,cap_idx].max(-1)[1]
601
+
602
+ texts = self.lang_encoder.tokenizer.batch_decode(caping_lang_token, skip_special_tokens=False)
603
+ texts_new = []
604
+
605
+ for x in texts:
606
+ x = x.split('<|endoftext|>')[0]
607
+ x = x.replace('<|endoftext|>','')
608
+ x = x.replace('<|startoftext|>','')
609
+ x = x.strip()
610
+ texts_new.append(x)
611
+
612
+ out = {'pred_captionings': caping_lang_token,
613
+ 'pred_texts': texts_new}
614
+ return out
615
+
616
+
617
+ def forward_prediction_heads(self, output, mask_features, attn_mask_target_size, layer_id=-1, task='seg'):
618
+ decoder_output = self.decoder_norm(output)
619
+ decoder_output = decoder_output.transpose(0, 1)
620
+
621
+ # extract image captioning token from decoder output.
622
+ if self.task_switch['captioning'] and (task == 'vlp' or task == 'captioning_infer'):
623
+ outputs_captionting = decoder_output[:,self.num_queries:] @ self.caping_embed
624
+ else:
625
+ outputs_captionting = None
626
+
627
+ # recompute class token output.
628
+ norm_decoder_output = decoder_output / (decoder_output.norm(dim=-1, keepdim=True) + 1e-7)
629
+ obj_token = norm_decoder_output[:,:self.num_queries-1]
630
+ cls_token = norm_decoder_output[:,self.num_queries-1:self.num_queries]
631
+
632
+ sim = (cls_token @ obj_token.transpose(1,2)).softmax(-1)[:,0,:,None] # TODO include class token.
633
+ cls_token = (sim * decoder_output[:,:self.num_queries-1]).sum(dim=1, keepdim=True)
634
+
635
+ if (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']) \
636
+ or ((self.training and task == 'openimage') and self.task_switch['openimage']['grounding']):
637
+ decoder_output = torch.cat((decoder_output[:,:self.num_queries-1], cls_token, decoder_output[:,self.num_queries:2*self.num_queries-1]), dim=1)
638
+ else:
639
+ decoder_output = torch.cat((decoder_output[:,:self.num_queries-1], cls_token), dim=1)
640
+
641
+ # compute class, mask and bbox.
642
+ class_embed = decoder_output @ self.class_embed
643
+ # HACK do not compute similarity if mask is not on
644
+ outputs_class = self.lang_encoder.compute_similarity(class_embed, fake=(((not self.task_switch['mask']) and self.training) or (task == 'openimage')))
645
+
646
+ if self.task_switch['mask'] or self.task_switch['openimage']['mask']:
647
+ mask_embed = self.mask_embed(decoder_output)
648
+ outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
649
+
650
+ # NOTE: prediction is of higher-resolution
651
+ # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
652
+ attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)
653
+
654
+ # must use bool type
655
+ # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
656
+ attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
657
+ attn_mask = attn_mask.detach()
658
+
659
+ # NOTE: fill False for cls token (JY)
660
+ attn_mask[:, self.num_queries:self.num_queries+1].fill_(False)
661
+ else:
662
+ outputs_mask = None
663
+ attn_mask = torch.zeros((list(decoder_output.shape[:2]) + [attn_mask_target_size[0]*attn_mask_target_size[1]]), device=decoder_output.device).repeat(self.num_heads, 1, 1).bool()
664
+
665
+ outputs_bbox = [None for i in range(len(decoder_output))]
666
+ if self.task_switch['bbox']:
667
+ outputs_bbox = self.bbox_embed(decoder_output)
668
+
669
+ outputs_caption = None
670
+ if self.task_switch['caption']:
671
+ outputs_caption = class_embed
672
+
673
+
674
+ results = {
675
+ "outputs_class": outputs_class,
676
+ "outputs_mask": outputs_mask,
677
+ "outputs_bbox": outputs_bbox,
678
+ "attn_mask": attn_mask,
679
+ "outputs_caption": outputs_caption,
680
+ "outputs_captionting": outputs_captionting,
681
+ }
682
+ return results
683
+
684
+ @torch.jit.unused
685
+ def _set_aux_loss(self, outputs_class, outputs_seg_masks, outputs_boxes, outputs_captions):
686
+ # this is a workaround to make torchscript happy, as torchscript
687
+ # doesn't support dictionary with non-homogeneous values, such
688
+ # as a dict having both a Tensor and a list.
689
+ if self.mask_classification:
690
+ return [
691
+ {"pred_logits": a, "pred_masks": b, "pred_boxes": c, "pred_captions": d}
692
+ for a, b, c, d in zip(outputs_class[:-1], outputs_seg_masks[:-1], outputs_boxes[:-1], outputs_captions[:-1])
693
+ ]
694
+ else:
695
+ return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]
696
+
697
+
698
+ @register_decoder
699
+ def get_masked_transformer_decoder(cfg, in_channels, lang_encoder, mask_classification, extra):
700
+ return MultiScaleMaskedTransformerDecoder(cfg, in_channels, lang_encoder, mask_classification, extra)
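
A closing note on the self_attn_mask buffer registered in the decoder constructor: it combines a block mask (object and class queries never attend to caption tokens, and the object queries and the class query are kept from attending to each other) with a causal mask over the caption tokens. A self-contained sketch of that construction, using placeholder sizes instead of the configured NUM_OBJECT_QUERIES and CONTEXT_LENGTH:

import torch

def build_self_attn_mask(num_queries, contxt_len):
    # True = attention blocked (nn.MultiheadAttention BoolTensor convention)
    n = num_queries + contxt_len
    mask = torch.zeros((1, n, n), dtype=torch.bool)
    # object + class queries never attend to caption tokens
    mask[:, :num_queries, num_queries:] = True
    # caption tokens attend only to themselves and earlier caption tokens (causal)
    mask[:, num_queries:, num_queries:] = torch.triu(
        torch.ones((1, contxt_len, contxt_len)), diagonal=1).bool()
    # object queries and the class query (last slot of the first block) are kept apart
    mask[:, :num_queries - 1, num_queries - 1:num_queries] = True
    mask[:, num_queries - 1:num_queries, :num_queries - 1] = True
    return mask

# example with placeholder sizes: 101 object/class queries and a 77-token caption context
# print(build_self_attn_mask(101, 77).shape)  # torch.Size([1, 178, 178])
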